refactor(mpc-system): migrate to party-driven architecture with PartyID-based routing
- Remove Address field from PartyEndpoint (parties connect to router themselves) - Update K8s Discovery to only manage PartyID and Role labels - Add Party registration and SessionEvent protobuf definitions - Implement PartyRegistry and SessionEventBroadcaster domain logic - Add RegisterParty and SubscribeSessionEvents gRPC handlers - Prepare infrastructure for party-driven MPC coordination This is the first phase of migrating from coordinator-driven to party-driven architecture following international MPC system design patterns.
This commit is contained in:
parent
e975e9d86c
commit
747e4ae8ef
|
|
@ -1,28 +1,28 @@
|
||||||
{
|
{
|
||||||
"permissions": {
|
"permissions": {
|
||||||
"allow": [
|
"allow": [
|
||||||
"Bash(dir:*)",
|
"Bash(dir:*)",
|
||||||
"Bash(tree:*)",
|
"Bash(tree:*)",
|
||||||
"Bash(find:*)",
|
"Bash(find:*)",
|
||||||
"Bash(ls -la \"c:\\Users\\dong\\Desktop\\rwadurian\\backend\\services\"\" 2>/dev/null || dir \"c:UsersdongDesktoprwadurianbackendservices\"\")",
|
"Bash(ls -la \"c:\\Users\\dong\\Desktop\\rwadurian\\backend\\services\"\" 2>/dev/null || dir \"c:UsersdongDesktoprwadurianbackendservices\"\")",
|
||||||
"Bash(mkdir:*)",
|
"Bash(mkdir:*)",
|
||||||
"Bash(npm run build:*)",
|
"Bash(npm run build:*)",
|
||||||
"Bash(npx nest build)",
|
"Bash(npx nest build)",
|
||||||
"Bash(npm install)",
|
"Bash(npm install)",
|
||||||
"Bash(npx prisma migrate dev:*)",
|
"Bash(npx prisma migrate dev:*)",
|
||||||
"Bash(npx jest:*)",
|
"Bash(npx jest:*)",
|
||||||
"Bash(flutter test:*)",
|
"Bash(flutter test:*)",
|
||||||
"Bash(flutter analyze:*)",
|
"Bash(flutter analyze:*)",
|
||||||
"Bash(findstr:*)",
|
"Bash(findstr:*)",
|
||||||
"Bash(flutter pub get:*)",
|
"Bash(flutter pub get:*)",
|
||||||
"Bash(cat:*)",
|
"Bash(cat:*)",
|
||||||
"Bash(git add:*)",
|
"Bash(git add:*)",
|
||||||
"Bash(git commit -m \"$(cat <<''EOF''\nrefactor(infra): 统一微服务基础设施为共享模式\n\n- 将 presence-service 添加到主 docker-compose.yml(端口 3011,Redis DB 10)\n- 更新 init-databases.sh 添加 rwa_admin 和 rwa_presence 数据库\n- 重构 admin-service/deploy.sh 使用共享基础设施\n- 重构 presence-service/deploy.sh 使用共享基础设施\n- 添加 authorization-service 开发指南文档\n\n解决多个微服务独立启动重复基础设施(PostgreSQL/Redis/Kafka)的问题\n\n🤖 Generated with [Claude Code](https://claude.com/claude-code)\n\nCo-Authored-By: Claude <noreply@anthropic.com>\nEOF\n)\")",
|
"Bash(git commit -m \"$(cat <<''EOF''\nrefactor(infra): 统一微服务基础设施为共享模式\n\n- 将 presence-service 添加到主 docker-compose.yml(端口 3011,Redis DB 10)\n- 更新 init-databases.sh 添加 rwa_admin 和 rwa_presence 数据库\n- 重构 admin-service/deploy.sh 使用共享基础设施\n- 重构 presence-service/deploy.sh 使用共享基础设施\n- 添加 authorization-service 开发指南文档\n\n解决多个微服务独立启动重复基础设施(PostgreSQL/Redis/Kafka)的问题\n\n🤖 Generated with [Claude Code](https://claude.com/claude-code)\n\nCo-Authored-By: Claude <noreply@anthropic.com>\nEOF\n)\")",
|
||||||
"Bash(git push)",
|
"Bash(git push)",
|
||||||
"Bash(git commit -m \"$(cat <<''EOF''\nfeat(admin-service): 增强移动端版本上传功能\n\n- 添加 APK/IPA 文件解析器自动提取版本信息\n- 支持从安装包自动读取 versionName 和 versionCode\n- 添加 adbkit-apkreader 依赖解析 APK 文件\n- 添加 plist 依赖解析 IPA 文件\n- 优化上传接口支持自动填充版本信息\n\n🤖 Generated with [Claude Code](https://claude.com/claude-code)\n\nCo-Authored-By: Claude <noreply@anthropic.com>\nEOF\n)\")",
|
"Bash(git commit -m \"$(cat <<''EOF''\nfeat(admin-service): 增强移动端版本上传功能\n\n- 添加 APK/IPA 文件解析器自动提取版本信息\n- 支持从安装包自动读取 versionName 和 versionCode\n- 添加 adbkit-apkreader 依赖解析 APK 文件\n- 添加 plist 依赖解析 IPA 文件\n- 优化上传接口支持自动填充版本信息\n\n🤖 Generated with [Claude Code](https://claude.com/claude-code)\n\nCo-Authored-By: Claude <noreply@anthropic.com>\nEOF\n)\")",
|
||||||
"Bash(git commit:*)"
|
"Bash(git commit:*)"
|
||||||
],
|
],
|
||||||
"deny": [],
|
"deny": [],
|
||||||
"ask": []
|
"ask": []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,14 @@
|
||||||
"Bash(wsl.exe -- bash -c 'find ~/rwadurian/backend/mpc-system/services/server-party -name \"\"main.go\"\" -path \"\"*/cmd/server/*\"\"')",
|
"Bash(wsl.exe -- bash -c 'find ~/rwadurian/backend/mpc-system/services/server-party -name \"\"main.go\"\" -path \"\"*/cmd/server/*\"\"')",
|
||||||
"Bash(wsl.exe -- bash -c 'cat ~/rwadurian/backend/mpc-system/services/server-party/cmd/server/main.go | grep -E \"\"grpc|GRPC|gRPC|50051\"\" | head -20')",
|
"Bash(wsl.exe -- bash -c 'cat ~/rwadurian/backend/mpc-system/services/server-party/cmd/server/main.go | grep -E \"\"grpc|GRPC|gRPC|50051\"\" | head -20')",
|
||||||
"Bash(wsl.exe -- bash:*)",
|
"Bash(wsl.exe -- bash:*)",
|
||||||
"Bash(dir:*)"
|
"Bash(dir:*)",
|
||||||
|
"Bash(go version:*)",
|
||||||
|
"Bash(go mod download:*)",
|
||||||
|
"Bash(go build:*)",
|
||||||
|
"Bash(go mod tidy:*)",
|
||||||
|
"Bash(findstr:*)",
|
||||||
|
"Bash(del \"c:\\Users\\dong\\Desktop\\rwadurian\\backend\\mpc-system\\PARTY_ROLE_VERIFICATION_REPORT.md\")",
|
||||||
|
"Bash(protoc:*)"
|
||||||
],
|
],
|
||||||
"deny": [],
|
"deny": [],
|
||||||
"ask": []
|
"ask": []
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,378 +1,378 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# RWADurian API Gateway (Kong) - 部署脚本
|
# RWADurian API Gateway (Kong) - 部署脚本
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Usage:
|
# Usage:
|
||||||
# ./deploy.sh up # 启动网关
|
# ./deploy.sh up # 启动网关
|
||||||
# ./deploy.sh down # 停止网关
|
# ./deploy.sh down # 停止网关
|
||||||
# ./deploy.sh restart # 重启网关
|
# ./deploy.sh restart # 重启网关
|
||||||
# ./deploy.sh logs # 查看日志
|
# ./deploy.sh logs # 查看日志
|
||||||
# ./deploy.sh status # 查看状态
|
# ./deploy.sh status # 查看状态
|
||||||
# ./deploy.sh health # 健康检查
|
# ./deploy.sh health # 健康检查
|
||||||
# ./deploy.sh reload # 重载 Kong 配置
|
# ./deploy.sh reload # 重载 Kong 配置
|
||||||
# ./deploy.sh routes # 查看所有路由
|
# ./deploy.sh routes # 查看所有路由
|
||||||
# ./deploy.sh monitoring # 启动监控栈 (Prometheus + Grafana)
|
# ./deploy.sh monitoring # 启动监控栈 (Prometheus + Grafana)
|
||||||
# ./deploy.sh metrics # 查看 Prometheus 指标
|
# ./deploy.sh metrics # 查看 Prometheus 指标
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
# 颜色定义
|
# 颜色定义
|
||||||
RED='\033[0;31m'
|
RED='\033[0;31m'
|
||||||
GREEN='\033[0;32m'
|
GREEN='\033[0;32m'
|
||||||
YELLOW='\033[1;33m'
|
YELLOW='\033[1;33m'
|
||||||
BLUE='\033[0;34m'
|
BLUE='\033[0;34m'
|
||||||
NC='\033[0m'
|
NC='\033[0m'
|
||||||
|
|
||||||
# 日志函数
|
# 日志函数
|
||||||
log_info() { echo -e "${BLUE}[INFO]${NC} $1"; }
|
log_info() { echo -e "${BLUE}[INFO]${NC} $1"; }
|
||||||
log_success() { echo -e "${GREEN}[SUCCESS]${NC} $1"; }
|
log_success() { echo -e "${GREEN}[SUCCESS]${NC} $1"; }
|
||||||
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
||||||
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
||||||
|
|
||||||
# 项目信息
|
# 项目信息
|
||||||
PROJECT_NAME="rwa-api-gateway"
|
PROJECT_NAME="rwa-api-gateway"
|
||||||
KONG_ADMIN_URL="http://localhost:8001"
|
KONG_ADMIN_URL="http://localhost:8001"
|
||||||
KONG_PROXY_URL="http://localhost:8000"
|
KONG_PROXY_URL="http://localhost:8000"
|
||||||
|
|
||||||
# 脚本目录
|
# 脚本目录
|
||||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
|
||||||
# 切换到脚本所在目录
|
# 切换到脚本所在目录
|
||||||
cd "$SCRIPT_DIR"
|
cd "$SCRIPT_DIR"
|
||||||
|
|
||||||
# 加载环境变量
|
# 加载环境变量
|
||||||
if [ -f ".env" ]; then
|
if [ -f ".env" ]; then
|
||||||
log_info "Loading environment from .env file"
|
log_info "Loading environment from .env file"
|
||||||
set -a
|
set -a
|
||||||
source .env
|
source .env
|
||||||
set +a
|
set +a
|
||||||
elif [ -f ".env.example" ]; then
|
elif [ -f ".env.example" ]; then
|
||||||
log_warn ".env file not found!"
|
log_warn ".env file not found!"
|
||||||
log_warn "Creating .env from .env.example..."
|
log_warn "Creating .env from .env.example..."
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
log_error "Please edit .env file to configure your environment, then run again"
|
log_error "Please edit .env file to configure your environment, then run again"
|
||||||
exit 1
|
exit 1
|
||||||
else
|
else
|
||||||
log_error "Neither .env nor .env.example found!"
|
log_error "Neither .env nor .env.example found!"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# 检查 Docker
|
# 检查 Docker
|
||||||
check_docker() {
|
check_docker() {
|
||||||
if ! command -v docker &> /dev/null; then
|
if ! command -v docker &> /dev/null; then
|
||||||
log_error "Docker 未安装"
|
log_error "Docker 未安装"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
if ! docker info &> /dev/null; then
|
if ! docker info &> /dev/null; then
|
||||||
log_error "Docker 服务未运行"
|
log_error "Docker 服务未运行"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# 检查 Docker Compose
|
# 检查 Docker Compose
|
||||||
check_docker_compose() {
|
check_docker_compose() {
|
||||||
if docker compose version &> /dev/null; then
|
if docker compose version &> /dev/null; then
|
||||||
COMPOSE_CMD="docker compose"
|
COMPOSE_CMD="docker compose"
|
||||||
elif command -v docker-compose &> /dev/null; then
|
elif command -v docker-compose &> /dev/null; then
|
||||||
COMPOSE_CMD="docker-compose"
|
COMPOSE_CMD="docker-compose"
|
||||||
else
|
else
|
||||||
log_error "Docker Compose 未安装"
|
log_error "Docker Compose 未安装"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# 检查后端服务连通性(可选)
|
# 检查后端服务连通性(可选)
|
||||||
check_backend() {
|
check_backend() {
|
||||||
local BACKEND_IP="${BACKEND_SERVER_IP:-192.168.1.111}"
|
local BACKEND_IP="${BACKEND_SERVER_IP:-192.168.1.111}"
|
||||||
log_info "检查后端服务器 $BACKEND_IP 连通性..."
|
log_info "检查后端服务器 $BACKEND_IP 连通性..."
|
||||||
if ping -c 1 -W 2 $BACKEND_IP &> /dev/null; then
|
if ping -c 1 -W 2 $BACKEND_IP &> /dev/null; then
|
||||||
log_success "后端服务器可达"
|
log_success "后端服务器可达"
|
||||||
else
|
else
|
||||||
log_warn "无法 ping 通后端服务器 $BACKEND_IP"
|
log_warn "无法 ping 通后端服务器 $BACKEND_IP"
|
||||||
log_warn "请确保后端服务已启动且网络可达"
|
log_warn "请确保后端服务已启动且网络可达"
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# 启动服务
|
# 启动服务
|
||||||
cmd_up() {
|
cmd_up() {
|
||||||
log_info "启动 Kong API Gateway..."
|
log_info "启动 Kong API Gateway..."
|
||||||
check_backend
|
check_backend
|
||||||
$COMPOSE_CMD up -d
|
$COMPOSE_CMD up -d
|
||||||
|
|
||||||
log_info "等待 Kong 启动..."
|
log_info "等待 Kong 启动..."
|
||||||
sleep 10
|
sleep 10
|
||||||
|
|
||||||
# 检查状态
|
# 检查状态
|
||||||
if docker ps | grep -q rwa-kong; then
|
if docker ps | grep -q rwa-kong; then
|
||||||
log_success "Kong API Gateway 启动成功!"
|
log_success "Kong API Gateway 启动成功!"
|
||||||
echo ""
|
echo ""
|
||||||
echo "服务地址:"
|
echo "服务地址:"
|
||||||
echo " Proxy: http://localhost:8000"
|
echo " Proxy: http://localhost:8000"
|
||||||
echo " Admin API: http://localhost:8001"
|
echo " Admin API: http://localhost:8001"
|
||||||
echo " Admin GUI: http://localhost:8002"
|
echo " Admin GUI: http://localhost:8002"
|
||||||
echo ""
|
echo ""
|
||||||
echo "查看路由: ./deploy.sh routes"
|
echo "查看路由: ./deploy.sh routes"
|
||||||
else
|
else
|
||||||
log_error "Kong 启动失败,查看日志: ./deploy.sh logs"
|
log_error "Kong 启动失败,查看日志: ./deploy.sh logs"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# 停止服务
|
# 停止服务
|
||||||
cmd_down() {
|
cmd_down() {
|
||||||
log_info "停止 Kong API Gateway..."
|
log_info "停止 Kong API Gateway..."
|
||||||
$COMPOSE_CMD down
|
$COMPOSE_CMD down
|
||||||
log_success "Kong 已停止"
|
log_success "Kong 已停止"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 重启服务
|
# 重启服务
|
||||||
cmd_restart() {
|
cmd_restart() {
|
||||||
log_info "重启 Kong API Gateway..."
|
log_info "重启 Kong API Gateway..."
|
||||||
$COMPOSE_CMD restart
|
$COMPOSE_CMD restart
|
||||||
log_success "Kong 已重启"
|
log_success "Kong 已重启"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 查看日志
|
# 查看日志
|
||||||
cmd_logs() {
|
cmd_logs() {
|
||||||
$COMPOSE_CMD logs -f
|
$COMPOSE_CMD logs -f
|
||||||
}
|
}
|
||||||
|
|
||||||
# 查看状态
|
# 查看状态
|
||||||
cmd_status() {
|
cmd_status() {
|
||||||
log_info "Kong API Gateway 状态:"
|
log_info "Kong API Gateway 状态:"
|
||||||
$COMPOSE_CMD ps
|
$COMPOSE_CMD ps
|
||||||
}
|
}
|
||||||
|
|
||||||
# 健康检查
|
# 健康检查
|
||||||
cmd_health() {
|
cmd_health() {
|
||||||
log_info "Kong 健康检查..."
|
log_info "Kong 健康检查..."
|
||||||
|
|
||||||
# 检查 Kong 状态
|
# 检查 Kong 状态
|
||||||
response=$(curl -s $KONG_ADMIN_URL/status 2>/dev/null)
|
response=$(curl -s $KONG_ADMIN_URL/status 2>/dev/null)
|
||||||
if [ $? -eq 0 ]; then
|
if [ $? -eq 0 ]; then
|
||||||
log_success "Kong Admin API 正常"
|
log_success "Kong Admin API 正常"
|
||||||
echo "$response" | python3 -m json.tool 2>/dev/null || echo "$response"
|
echo "$response" | python3 -m json.tool 2>/dev/null || echo "$response"
|
||||||
else
|
else
|
||||||
log_error "Kong Admin API 不可用"
|
log_error "Kong Admin API 不可用"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# 重载配置 (触发 deck sync)
|
# 重载配置 (触发 deck sync)
|
||||||
cmd_reload() {
|
cmd_reload() {
|
||||||
log_info "重载 Kong 配置..."
|
log_info "重载 Kong 配置..."
|
||||||
$COMPOSE_CMD run --rm kong-config
|
$COMPOSE_CMD run --rm kong-config
|
||||||
log_success "配置已重载"
|
log_success "配置已重载"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 同步配置到数据库
|
# 同步配置到数据库
|
||||||
cmd_sync() {
|
cmd_sync() {
|
||||||
log_info "同步 kong.yml 配置到数据库..."
|
log_info "同步 kong.yml 配置到数据库..."
|
||||||
$COMPOSE_CMD run --rm kong-config
|
$COMPOSE_CMD run --rm kong-config
|
||||||
log_success "配置同步完成"
|
log_success "配置同步完成"
|
||||||
echo ""
|
echo ""
|
||||||
echo "查看路由: ./deploy.sh routes"
|
echo "查看路由: ./deploy.sh routes"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 查看所有路由
|
# 查看所有路由
|
||||||
cmd_routes() {
|
cmd_routes() {
|
||||||
log_info "Kong 路由列表:"
|
log_info "Kong 路由列表:"
|
||||||
curl -s $KONG_ADMIN_URL/routes | python3 -m json.tool 2>/dev/null || curl -s $KONG_ADMIN_URL/routes
|
curl -s $KONG_ADMIN_URL/routes | python3 -m json.tool 2>/dev/null || curl -s $KONG_ADMIN_URL/routes
|
||||||
}
|
}
|
||||||
|
|
||||||
# 查看所有服务
|
# 查看所有服务
|
||||||
cmd_services() {
|
cmd_services() {
|
||||||
log_info "Kong 服务列表:"
|
log_info "Kong 服务列表:"
|
||||||
curl -s $KONG_ADMIN_URL/services | python3 -m json.tool 2>/dev/null || curl -s $KONG_ADMIN_URL/services
|
curl -s $KONG_ADMIN_URL/services | python3 -m json.tool 2>/dev/null || curl -s $KONG_ADMIN_URL/services
|
||||||
}
|
}
|
||||||
|
|
||||||
# 测试 API
|
# 测试 API
|
||||||
cmd_test() {
|
cmd_test() {
|
||||||
log_info "测试 API 路由..."
|
log_info "测试 API 路由..."
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "测试 /api/v1/versions (admin-service):"
|
echo "测试 /api/v1/versions (admin-service):"
|
||||||
curl -s -o /dev/null -w " HTTP Status: %{http_code}\n" $KONG_PROXY_URL/api/v1/versions
|
curl -s -o /dev/null -w " HTTP Status: %{http_code}\n" $KONG_PROXY_URL/api/v1/versions
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "测试 /api/v1/auth (identity-service):"
|
echo "测试 /api/v1/auth (identity-service):"
|
||||||
curl -s -o /dev/null -w " HTTP Status: %{http_code}\n" $KONG_PROXY_URL/api/v1/auth
|
curl -s -o /dev/null -w " HTTP Status: %{http_code}\n" $KONG_PROXY_URL/api/v1/auth
|
||||||
}
|
}
|
||||||
|
|
||||||
# 清理
|
# 清理
|
||||||
cmd_clean() {
|
cmd_clean() {
|
||||||
log_info "清理 Kong 容器和数据..."
|
log_info "清理 Kong 容器和数据..."
|
||||||
$COMPOSE_CMD down -v --remove-orphans
|
$COMPOSE_CMD down -v --remove-orphans
|
||||||
docker image prune -f
|
docker image prune -f
|
||||||
log_success "清理完成"
|
log_success "清理完成"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 启动监控栈
|
# 启动监控栈
|
||||||
cmd_monitoring_up() {
|
cmd_monitoring_up() {
|
||||||
log_info "启动监控栈 (Prometheus + Grafana)..."
|
log_info "启动监控栈 (Prometheus + Grafana)..."
|
||||||
$COMPOSE_CMD -f docker-compose.yml -f docker-compose.monitoring.yml up -d prometheus grafana
|
$COMPOSE_CMD -f docker-compose.yml -f docker-compose.monitoring.yml up -d prometheus grafana
|
||||||
|
|
||||||
log_info "等待服务启动..."
|
log_info "等待服务启动..."
|
||||||
sleep 5
|
sleep 5
|
||||||
|
|
||||||
log_success "监控栈启动成功!"
|
log_success "监控栈启动成功!"
|
||||||
echo ""
|
echo ""
|
||||||
echo "监控服务地址:"
|
echo "监控服务地址:"
|
||||||
echo " Grafana: http://localhost:3030 (admin/admin123)"
|
echo " Grafana: http://localhost:3030 (admin/admin123)"
|
||||||
echo " Prometheus: http://localhost:9099"
|
echo " Prometheus: http://localhost:9099"
|
||||||
echo " Kong 指标: http://localhost:8001/metrics"
|
echo " Kong 指标: http://localhost:8001/metrics"
|
||||||
echo ""
|
echo ""
|
||||||
}
|
}
|
||||||
|
|
||||||
# 安装监控栈 (包括 Nginx + SSL)
|
# 安装监控栈 (包括 Nginx + SSL)
|
||||||
cmd_monitoring_install() {
|
cmd_monitoring_install() {
|
||||||
local domain="${1:-monitor.szaiai.com}"
|
local domain="${1:-monitor.szaiai.com}"
|
||||||
log_info "安装监控栈..."
|
log_info "安装监控栈..."
|
||||||
|
|
||||||
if [ ! -f "$SCRIPT_DIR/scripts/install-monitor.sh" ]; then
|
if [ ! -f "$SCRIPT_DIR/scripts/install-monitor.sh" ]; then
|
||||||
log_error "安装脚本不存在: scripts/install-monitor.sh"
|
log_error "安装脚本不存在: scripts/install-monitor.sh"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
sudo bash "$SCRIPT_DIR/scripts/install-monitor.sh" "$domain"
|
sudo bash "$SCRIPT_DIR/scripts/install-monitor.sh" "$domain"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 停止监控栈
|
# 停止监控栈
|
||||||
cmd_monitoring_down() {
|
cmd_monitoring_down() {
|
||||||
log_info "停止监控栈..."
|
log_info "停止监控栈..."
|
||||||
docker stop rwa-prometheus rwa-grafana 2>/dev/null || true
|
docker stop rwa-prometheus rwa-grafana 2>/dev/null || true
|
||||||
docker rm rwa-prometheus rwa-grafana 2>/dev/null || true
|
docker rm rwa-prometheus rwa-grafana 2>/dev/null || true
|
||||||
log_success "监控栈已停止"
|
log_success "监控栈已停止"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 查看 Prometheus 指标
|
# 查看 Prometheus 指标
|
||||||
cmd_metrics() {
|
cmd_metrics() {
|
||||||
log_info "Kong Prometheus 指标概览:"
|
log_info "Kong Prometheus 指标概览:"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# 获取关键指标
|
# 获取关键指标
|
||||||
metrics=$(curl -s $KONG_ADMIN_URL/metrics 2>/dev/null)
|
metrics=$(curl -s $KONG_ADMIN_URL/metrics 2>/dev/null)
|
||||||
|
|
||||||
if [ $? -eq 0 ]; then
|
if [ $? -eq 0 ]; then
|
||||||
echo "=== 请求统计 ==="
|
echo "=== 请求统计 ==="
|
||||||
echo "$metrics" | grep -E "^kong_http_requests_total" | head -20
|
echo "$metrics" | grep -E "^kong_http_requests_total" | head -20
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 延迟统计 ==="
|
echo "=== 延迟统计 ==="
|
||||||
echo "$metrics" | grep -E "^kong_latency_" | head -10
|
echo "$metrics" | grep -E "^kong_latency_" | head -10
|
||||||
echo ""
|
echo ""
|
||||||
echo "完整指标: curl $KONG_ADMIN_URL/metrics"
|
echo "完整指标: curl $KONG_ADMIN_URL/metrics"
|
||||||
else
|
else
|
||||||
log_error "无法获取指标,请确保 Kong 正在运行且 prometheus 插件已启用"
|
log_error "无法获取指标,请确保 Kong 正在运行且 prometheus 插件已启用"
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# 显示帮助
|
# 显示帮助
|
||||||
show_help() {
|
show_help() {
|
||||||
echo ""
|
echo ""
|
||||||
echo "RWADurian API Gateway (Kong) 部署脚本"
|
echo "RWADurian API Gateway (Kong) 部署脚本"
|
||||||
echo ""
|
echo ""
|
||||||
echo "用法: ./deploy.sh [命令]"
|
echo "用法: ./deploy.sh [命令]"
|
||||||
echo ""
|
echo ""
|
||||||
echo "命令:"
|
echo "命令:"
|
||||||
echo " up 启动 Kong 网关"
|
echo " up 启动 Kong 网关"
|
||||||
echo " down 停止 Kong 网关"
|
echo " down 停止 Kong 网关"
|
||||||
echo " restart 重启 Kong 网关"
|
echo " restart 重启 Kong 网关"
|
||||||
echo " logs 查看日志"
|
echo " logs 查看日志"
|
||||||
echo " status 查看状态"
|
echo " status 查看状态"
|
||||||
echo " health 健康检查"
|
echo " health 健康检查"
|
||||||
echo " sync 同步 kong.yml 配置到数据库"
|
echo " sync 同步 kong.yml 配置到数据库"
|
||||||
echo " reload 重载 Kong 配置 (同 sync)"
|
echo " reload 重载 Kong 配置 (同 sync)"
|
||||||
echo " routes 查看所有路由"
|
echo " routes 查看所有路由"
|
||||||
echo " services 查看所有服务"
|
echo " services 查看所有服务"
|
||||||
echo " test 测试 API 路由"
|
echo " test 测试 API 路由"
|
||||||
echo " clean 清理容器和数据"
|
echo " clean 清理容器和数据"
|
||||||
echo ""
|
echo ""
|
||||||
echo "监控命令:"
|
echo "监控命令:"
|
||||||
echo " monitoring install [domain] 一键安装监控 (Nginx+SSL+服务)"
|
echo " monitoring install [domain] 一键安装监控 (Nginx+SSL+服务)"
|
||||||
echo " monitoring up 启动监控栈"
|
echo " monitoring up 启动监控栈"
|
||||||
echo " monitoring down 停止监控栈"
|
echo " monitoring down 停止监控栈"
|
||||||
echo " metrics 查看 Prometheus 指标"
|
echo " metrics 查看 Prometheus 指标"
|
||||||
echo ""
|
echo ""
|
||||||
echo " help 显示帮助"
|
echo " help 显示帮助"
|
||||||
echo ""
|
echo ""
|
||||||
echo "注意: 需要先启动 backend/services 才能启动 Kong"
|
echo "注意: 需要先启动 backend/services 才能启动 Kong"
|
||||||
echo ""
|
echo ""
|
||||||
}
|
}
|
||||||
|
|
||||||
# 主函数
|
# 主函数
|
||||||
main() {
|
main() {
|
||||||
check_docker
|
check_docker
|
||||||
check_docker_compose
|
check_docker_compose
|
||||||
|
|
||||||
case "${1:-help}" in
|
case "${1:-help}" in
|
||||||
up)
|
up)
|
||||||
cmd_up
|
cmd_up
|
||||||
;;
|
;;
|
||||||
down)
|
down)
|
||||||
cmd_down
|
cmd_down
|
||||||
;;
|
;;
|
||||||
restart)
|
restart)
|
||||||
cmd_restart
|
cmd_restart
|
||||||
;;
|
;;
|
||||||
logs)
|
logs)
|
||||||
cmd_logs
|
cmd_logs
|
||||||
;;
|
;;
|
||||||
status)
|
status)
|
||||||
cmd_status
|
cmd_status
|
||||||
;;
|
;;
|
||||||
health)
|
health)
|
||||||
cmd_health
|
cmd_health
|
||||||
;;
|
;;
|
||||||
sync)
|
sync)
|
||||||
cmd_sync
|
cmd_sync
|
||||||
;;
|
;;
|
||||||
reload)
|
reload)
|
||||||
cmd_reload
|
cmd_reload
|
||||||
;;
|
;;
|
||||||
routes)
|
routes)
|
||||||
cmd_routes
|
cmd_routes
|
||||||
;;
|
;;
|
||||||
services)
|
services)
|
||||||
cmd_services
|
cmd_services
|
||||||
;;
|
;;
|
||||||
test)
|
test)
|
||||||
cmd_test
|
cmd_test
|
||||||
;;
|
;;
|
||||||
clean)
|
clean)
|
||||||
cmd_clean
|
cmd_clean
|
||||||
;;
|
;;
|
||||||
monitoring)
|
monitoring)
|
||||||
case "${2:-up}" in
|
case "${2:-up}" in
|
||||||
install)
|
install)
|
||||||
cmd_monitoring_install "$3"
|
cmd_monitoring_install "$3"
|
||||||
;;
|
;;
|
||||||
up)
|
up)
|
||||||
cmd_monitoring_up
|
cmd_monitoring_up
|
||||||
;;
|
;;
|
||||||
down)
|
down)
|
||||||
cmd_monitoring_down
|
cmd_monitoring_down
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
log_error "未知监控命令: $2"
|
log_error "未知监控命令: $2"
|
||||||
echo "用法: ./deploy.sh monitoring [install|up|down]"
|
echo "用法: ./deploy.sh monitoring [install|up|down]"
|
||||||
exit 1
|
exit 1
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
;;
|
;;
|
||||||
metrics)
|
metrics)
|
||||||
cmd_metrics
|
cmd_metrics
|
||||||
;;
|
;;
|
||||||
help|--help|-h)
|
help|--help|-h)
|
||||||
show_help
|
show_help
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
log_error "未知命令: $1"
|
log_error "未知命令: $1"
|
||||||
show_help
|
show_help
|
||||||
exit 1
|
exit 1
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
}
|
}
|
||||||
|
|
||||||
main "$@"
|
main "$@"
|
||||||
|
|
|
||||||
|
|
@ -1,67 +1,67 @@
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Kong Monitoring Stack - Prometheus + Grafana
|
# Kong Monitoring Stack - Prometheus + Grafana
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Usage:
|
# Usage:
|
||||||
# docker compose -f docker-compose.yml -f docker-compose.monitoring.yml up -d
|
# docker compose -f docker-compose.yml -f docker-compose.monitoring.yml up -d
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
services:
|
services:
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# Prometheus - 指标收集
|
# Prometheus - 指标收集
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
prometheus:
|
prometheus:
|
||||||
image: prom/prometheus:latest
|
image: prom/prometheus:latest
|
||||||
container_name: rwa-prometheus
|
container_name: rwa-prometheus
|
||||||
command:
|
command:
|
||||||
- '--config.file=/etc/prometheus/prometheus.yml'
|
- '--config.file=/etc/prometheus/prometheus.yml'
|
||||||
- '--storage.tsdb.path=/prometheus'
|
- '--storage.tsdb.path=/prometheus'
|
||||||
- '--web.console.libraries=/usr/share/prometheus/console_libraries'
|
- '--web.console.libraries=/usr/share/prometheus/console_libraries'
|
||||||
- '--web.console.templates=/usr/share/prometheus/consoles'
|
- '--web.console.templates=/usr/share/prometheus/consoles'
|
||||||
volumes:
|
volumes:
|
||||||
- ./prometheus.yml:/etc/prometheus/prometheus.yml:ro
|
- ./prometheus.yml:/etc/prometheus/prometheus.yml:ro
|
||||||
- prometheus_data:/prometheus
|
- prometheus_data:/prometheus
|
||||||
ports:
|
ports:
|
||||||
- "9099:9090" # 使用 9099 避免与已有服务冲突
|
- "9099:9090" # 使用 9099 避免与已有服务冲突
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks:
|
networks:
|
||||||
- rwa-network
|
- rwa-network
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# Grafana - 可视化仪表盘
|
# Grafana - 可视化仪表盘
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
grafana:
|
grafana:
|
||||||
image: grafana/grafana:latest
|
image: grafana/grafana:latest
|
||||||
container_name: rwa-grafana
|
container_name: rwa-grafana
|
||||||
environment:
|
environment:
|
||||||
- GF_SECURITY_ADMIN_USER=admin
|
- GF_SECURITY_ADMIN_USER=admin
|
||||||
- GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD:-admin123}
|
- GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD:-admin123}
|
||||||
- GF_USERS_ALLOW_SIGN_UP=false
|
- GF_USERS_ALLOW_SIGN_UP=false
|
||||||
# 反向代理支持
|
# 反向代理支持
|
||||||
- GF_SERVER_ROOT_URL=${GRAFANA_ROOT_URL:-http://localhost:3030}
|
- GF_SERVER_ROOT_URL=${GRAFANA_ROOT_URL:-http://localhost:3030}
|
||||||
- GF_SERVER_SERVE_FROM_SUB_PATH=false
|
- GF_SERVER_SERVE_FROM_SUB_PATH=false
|
||||||
# Grafana 10+ CORS/跨域配置 - 允许通过反向代理访问
|
# Grafana 10+ CORS/跨域配置 - 允许通过反向代理访问
|
||||||
- GF_SECURITY_ALLOW_EMBEDDING=true
|
- GF_SECURITY_ALLOW_EMBEDDING=true
|
||||||
- GF_SECURITY_COOKIE_SAMESITE=none
|
- GF_SECURITY_COOKIE_SAMESITE=none
|
||||||
- GF_SECURITY_COOKIE_SECURE=true
|
- GF_SECURITY_COOKIE_SECURE=true
|
||||||
- GF_AUTH_ANONYMOUS_ENABLED=false
|
- GF_AUTH_ANONYMOUS_ENABLED=false
|
||||||
volumes:
|
volumes:
|
||||||
- grafana_data:/var/lib/grafana
|
- grafana_data:/var/lib/grafana
|
||||||
- ./grafana/provisioning:/etc/grafana/provisioning:ro
|
- ./grafana/provisioning:/etc/grafana/provisioning:ro
|
||||||
ports:
|
ports:
|
||||||
- "3030:3000"
|
- "3030:3000"
|
||||||
depends_on:
|
depends_on:
|
||||||
- prometheus
|
- prometheus
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks:
|
networks:
|
||||||
- rwa-network
|
- rwa-network
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
prometheus_data:
|
prometheus_data:
|
||||||
driver: local
|
driver: local
|
||||||
grafana_data:
|
grafana_data:
|
||||||
driver: local
|
driver: local
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
rwa-network:
|
rwa-network:
|
||||||
external: true
|
external: true
|
||||||
name: ${NETWORK_NAME:-api-gateway_rwa-network}
|
name: ${NETWORK_NAME:-api-gateway_rwa-network}
|
||||||
|
|
|
||||||
|
|
@ -1,129 +1,129 @@
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Kong API Gateway - Docker Compose
|
# Kong API Gateway - Docker Compose
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Usage:
|
# Usage:
|
||||||
# ./deploy.sh up # 启动 Kong 网关
|
# ./deploy.sh up # 启动 Kong 网关
|
||||||
# ./deploy.sh down # 停止 Kong 网关
|
# ./deploy.sh down # 停止 Kong 网关
|
||||||
# ./deploy.sh logs # 查看日志
|
# ./deploy.sh logs # 查看日志
|
||||||
# ./deploy.sh status # 查看状态
|
# ./deploy.sh status # 查看状态
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
services:
|
services:
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# Kong Database
|
# Kong Database
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
kong-db:
|
kong-db:
|
||||||
image: docker.io/library/postgres:16-alpine
|
image: docker.io/library/postgres:16-alpine
|
||||||
container_name: rwa-kong-db
|
container_name: rwa-kong-db
|
||||||
environment:
|
environment:
|
||||||
POSTGRES_USER: kong
|
POSTGRES_USER: kong
|
||||||
POSTGRES_PASSWORD: ${KONG_PG_PASSWORD:-kong_password}
|
POSTGRES_PASSWORD: ${KONG_PG_PASSWORD:-kong_password}
|
||||||
POSTGRES_DB: kong
|
POSTGRES_DB: kong
|
||||||
volumes:
|
volumes:
|
||||||
- kong_db_data:/var/lib/postgresql/data
|
- kong_db_data:/var/lib/postgresql/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD-SHELL", "pg_isready -U kong"]
|
test: ["CMD-SHELL", "pg_isready -U kong"]
|
||||||
interval: 5s
|
interval: 5s
|
||||||
timeout: 5s
|
timeout: 5s
|
||||||
retries: 10
|
retries: 10
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks:
|
networks:
|
||||||
- rwa-network
|
- rwa-network
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# Kong Migrations (只运行一次)
|
# Kong Migrations (只运行一次)
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
kong-migrations:
|
kong-migrations:
|
||||||
image: docker.io/kong/kong-gateway:3.5
|
image: docker.io/kong/kong-gateway:3.5
|
||||||
container_name: rwa-kong-migrations
|
container_name: rwa-kong-migrations
|
||||||
command: kong migrations bootstrap
|
command: kong migrations bootstrap
|
||||||
environment:
|
environment:
|
||||||
KONG_DATABASE: postgres
|
KONG_DATABASE: postgres
|
||||||
KONG_PG_HOST: kong-db
|
KONG_PG_HOST: kong-db
|
||||||
KONG_PG_USER: kong
|
KONG_PG_USER: kong
|
||||||
KONG_PG_PASSWORD: ${KONG_PG_PASSWORD:-kong_password}
|
KONG_PG_PASSWORD: ${KONG_PG_PASSWORD:-kong_password}
|
||||||
KONG_PG_DATABASE: kong
|
KONG_PG_DATABASE: kong
|
||||||
depends_on:
|
depends_on:
|
||||||
kong-db:
|
kong-db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
restart: on-failure
|
restart: on-failure
|
||||||
networks:
|
networks:
|
||||||
- rwa-network
|
- rwa-network
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# Kong API Gateway
|
# Kong API Gateway
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
kong:
|
kong:
|
||||||
image: docker.io/kong/kong-gateway:3.5
|
image: docker.io/kong/kong-gateway:3.5
|
||||||
container_name: rwa-kong
|
container_name: rwa-kong
|
||||||
environment:
|
environment:
|
||||||
KONG_DATABASE: postgres
|
KONG_DATABASE: postgres
|
||||||
KONG_PG_HOST: kong-db
|
KONG_PG_HOST: kong-db
|
||||||
KONG_PG_USER: kong
|
KONG_PG_USER: kong
|
||||||
KONG_PG_PASSWORD: ${KONG_PG_PASSWORD:-kong_password}
|
KONG_PG_PASSWORD: ${KONG_PG_PASSWORD:-kong_password}
|
||||||
KONG_PG_DATABASE: kong
|
KONG_PG_DATABASE: kong
|
||||||
KONG_PROXY_ACCESS_LOG: /dev/stdout
|
KONG_PROXY_ACCESS_LOG: /dev/stdout
|
||||||
KONG_ADMIN_ACCESS_LOG: /dev/stdout
|
KONG_ADMIN_ACCESS_LOG: /dev/stdout
|
||||||
KONG_PROXY_ERROR_LOG: /dev/stderr
|
KONG_PROXY_ERROR_LOG: /dev/stderr
|
||||||
KONG_ADMIN_ERROR_LOG: /dev/stderr
|
KONG_ADMIN_ERROR_LOG: /dev/stderr
|
||||||
KONG_ADMIN_LISTEN: 0.0.0.0:8001
|
KONG_ADMIN_LISTEN: 0.0.0.0:8001
|
||||||
KONG_ADMIN_GUI_URL: ${KONG_ADMIN_GUI_URL:-http://localhost:8002}
|
KONG_ADMIN_GUI_URL: ${KONG_ADMIN_GUI_URL:-http://localhost:8002}
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000" # Proxy HTTP
|
- "8000:8000" # Proxy HTTP
|
||||||
- "8443:8443" # Proxy HTTPS
|
- "8443:8443" # Proxy HTTPS
|
||||||
- "8001:8001" # Admin API
|
- "8001:8001" # Admin API
|
||||||
- "8002:8002" # Admin GUI
|
- "8002:8002" # Admin GUI
|
||||||
depends_on:
|
depends_on:
|
||||||
kong-db:
|
kong-db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
kong-migrations:
|
kong-migrations:
|
||||||
condition: service_completed_successfully
|
condition: service_completed_successfully
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "kong", "health"]
|
test: ["CMD", "kong", "health"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 5
|
retries: 5
|
||||||
start_period: 30s
|
start_period: 30s
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks:
|
networks:
|
||||||
- rwa-network
|
- rwa-network
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# Kong Config Loader - 导入声明式配置到数据库
|
# Kong Config Loader - 导入声明式配置到数据库
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
kong-config:
|
kong-config:
|
||||||
image: docker.io/kong/deck:latest
|
image: docker.io/kong/deck:latest
|
||||||
container_name: rwa-kong-config
|
container_name: rwa-kong-config
|
||||||
command: >
|
command: >
|
||||||
gateway sync /etc/kong/kong.yml
|
gateway sync /etc/kong/kong.yml
|
||||||
--kong-addr http://kong:8001
|
--kong-addr http://kong:8001
|
||||||
environment:
|
environment:
|
||||||
# 禁用代理,避免继承宿主机的代理设置
|
# 禁用代理,避免继承宿主机的代理设置
|
||||||
http_proxy: ""
|
http_proxy: ""
|
||||||
https_proxy: ""
|
https_proxy: ""
|
||||||
HTTP_PROXY: ""
|
HTTP_PROXY: ""
|
||||||
HTTPS_PROXY: ""
|
HTTPS_PROXY: ""
|
||||||
no_proxy: "*"
|
no_proxy: "*"
|
||||||
NO_PROXY: "*"
|
NO_PROXY: "*"
|
||||||
volumes:
|
volumes:
|
||||||
- ./kong.yml:/etc/kong/kong.yml:ro
|
- ./kong.yml:/etc/kong/kong.yml:ro
|
||||||
depends_on:
|
depends_on:
|
||||||
kong:
|
kong:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
restart: on-failure
|
restart: on-failure
|
||||||
networks:
|
networks:
|
||||||
- rwa-network
|
- rwa-network
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# Volumes
|
# Volumes
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
volumes:
|
volumes:
|
||||||
kong_db_data:
|
kong_db_data:
|
||||||
driver: local
|
driver: local
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# Networks - 独立网络(分布式部署,Kong 通过外部 IP 访问后端服务)
|
# Networks - 独立网络(分布式部署,Kong 通过外部 IP 访问后端服务)
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
networks:
|
networks:
|
||||||
rwa-network:
|
rwa-network:
|
||||||
driver: bridge
|
driver: bridge
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
apiVersion: 1
|
apiVersion: 1
|
||||||
|
|
||||||
providers:
|
providers:
|
||||||
- name: 'Kong API Gateway'
|
- name: 'Kong API Gateway'
|
||||||
orgId: 1
|
orgId: 1
|
||||||
folder: ''
|
folder: ''
|
||||||
type: file
|
type: file
|
||||||
disableDeletion: false
|
disableDeletion: false
|
||||||
updateIntervalSeconds: 10
|
updateIntervalSeconds: 10
|
||||||
options:
|
options:
|
||||||
path: /etc/grafana/provisioning/dashboards
|
path: /etc/grafana/provisioning/dashboards
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,9 +1,9 @@
|
||||||
apiVersion: 1
|
apiVersion: 1
|
||||||
|
|
||||||
datasources:
|
datasources:
|
||||||
- name: Prometheus
|
- name: Prometheus
|
||||||
type: prometheus
|
type: prometheus
|
||||||
access: proxy
|
access: proxy
|
||||||
url: http://prometheus:9090
|
url: http://prometheus:9090
|
||||||
isDefault: true
|
isDefault: true
|
||||||
editable: false
|
editable: false
|
||||||
|
|
|
||||||
|
|
@ -1,245 +1,245 @@
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Kong API Gateway - 声明式配置
|
# Kong API Gateway - 声明式配置
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 分布式部署说明:
|
# 分布式部署说明:
|
||||||
# - Kong 服务器: 192.168.1.100
|
# - Kong 服务器: 192.168.1.100
|
||||||
# - 后端服务器: 192.168.1.111
|
# - 后端服务器: 192.168.1.111
|
||||||
#
|
#
|
||||||
# 使用方法:
|
# 使用方法:
|
||||||
# 1. 启动 Kong: ./deploy.sh up
|
# 1. 启动 Kong: ./deploy.sh up
|
||||||
# 2. 配置会自动加载
|
# 2. 配置会自动加载
|
||||||
#
|
#
|
||||||
# 文档: https://docs.konghq.com/gateway/latest/
|
# 文档: https://docs.konghq.com/gateway/latest/
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
_format_version: "3.0"
|
_format_version: "3.0"
|
||||||
_transform: true
|
_transform: true
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Services - 后端微服务定义
|
# Services - 后端微服务定义
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 注意: 使用外部 IP 地址,因为 Kong 和后端服务在不同服务器上
|
# 注意: 使用外部 IP 地址,因为 Kong 和后端服务在不同服务器上
|
||||||
# 后端服务器 IP: 192.168.1.111
|
# 后端服务器 IP: 192.168.1.111
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
services:
|
services:
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Identity Service - 身份认证服务
|
# Identity Service - 身份认证服务
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
- name: identity-service
|
- name: identity-service
|
||||||
url: http://192.168.1.111:3000
|
url: http://192.168.1.111:3000
|
||||||
routes:
|
routes:
|
||||||
- name: identity-auth
|
- name: identity-auth
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/auth
|
- /api/v1/auth
|
||||||
strip_path: false
|
strip_path: false
|
||||||
- name: identity-user
|
- name: identity-user
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/user
|
- /api/v1/user
|
||||||
strip_path: false
|
strip_path: false
|
||||||
- name: identity-users
|
- name: identity-users
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/users
|
- /api/v1/users
|
||||||
strip_path: false
|
strip_path: false
|
||||||
- name: identity-health
|
- name: identity-health
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/identity/health
|
- /api/v1/identity/health
|
||||||
strip_path: true
|
strip_path: true
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Wallet Service - 钱包服务
|
# Wallet Service - 钱包服务
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
- name: wallet-service
|
- name: wallet-service
|
||||||
url: http://192.168.1.111:3001
|
url: http://192.168.1.111:3001
|
||||||
routes:
|
routes:
|
||||||
- name: wallet-api
|
- name: wallet-api
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/wallets
|
- /api/v1/wallets
|
||||||
strip_path: false
|
strip_path: false
|
||||||
- name: wallet-health
|
- name: wallet-health
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/wallet/health
|
- /api/v1/wallet/health
|
||||||
strip_path: true
|
strip_path: true
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Backup Service - 备份服务
|
# Backup Service - 备份服务
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
- name: backup-service
|
- name: backup-service
|
||||||
url: http://192.168.1.111:3002
|
url: http://192.168.1.111:3002
|
||||||
routes:
|
routes:
|
||||||
- name: backup-api
|
- name: backup-api
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/backups
|
- /api/v1/backups
|
||||||
strip_path: false
|
strip_path: false
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Planting Service - 种植服务
|
# Planting Service - 种植服务
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
- name: planting-service
|
- name: planting-service
|
||||||
url: http://192.168.1.111:3003
|
url: http://192.168.1.111:3003
|
||||||
routes:
|
routes:
|
||||||
- name: planting-api
|
- name: planting-api
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/plantings
|
- /api/v1/plantings
|
||||||
- /api/v1/trees
|
- /api/v1/trees
|
||||||
strip_path: false
|
strip_path: false
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Referral Service - 推荐服务
|
# Referral Service - 推荐服务
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
- name: referral-service
|
- name: referral-service
|
||||||
url: http://192.168.1.111:3004
|
url: http://192.168.1.111:3004
|
||||||
routes:
|
routes:
|
||||||
- name: referral-api
|
- name: referral-api
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/referrals
|
- /api/v1/referrals
|
||||||
strip_path: false
|
strip_path: false
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Reward Service - 奖励服务
|
# Reward Service - 奖励服务
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
- name: reward-service
|
- name: reward-service
|
||||||
url: http://192.168.1.111:3005
|
url: http://192.168.1.111:3005
|
||||||
routes:
|
routes:
|
||||||
- name: reward-api
|
- name: reward-api
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/rewards
|
- /api/v1/rewards
|
||||||
strip_path: false
|
strip_path: false
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# MPC Service - 多方计算服务
|
# MPC Service - 多方计算服务
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
- name: mpc-service
|
- name: mpc-service
|
||||||
url: http://192.168.1.111:3006
|
url: http://192.168.1.111:3006
|
||||||
routes:
|
routes:
|
||||||
- name: mpc-api
|
- name: mpc-api
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/mpc
|
- /api/v1/mpc
|
||||||
strip_path: false
|
strip_path: false
|
||||||
- name: mpc-party-api
|
- name: mpc-party-api
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/mpc-party
|
- /api/v1/mpc-party
|
||||||
strip_path: false
|
strip_path: false
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Leaderboard Service - 排行榜服务
|
# Leaderboard Service - 排行榜服务
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
- name: leaderboard-service
|
- name: leaderboard-service
|
||||||
url: http://192.168.1.111:3007
|
url: http://192.168.1.111:3007
|
||||||
routes:
|
routes:
|
||||||
- name: leaderboard-api
|
- name: leaderboard-api
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/leaderboard
|
- /api/v1/leaderboard
|
||||||
strip_path: false
|
strip_path: false
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Reporting Service - 报表服务
|
# Reporting Service - 报表服务
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
- name: reporting-service
|
- name: reporting-service
|
||||||
url: http://192.168.1.111:3008
|
url: http://192.168.1.111:3008
|
||||||
routes:
|
routes:
|
||||||
- name: reporting-api
|
- name: reporting-api
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/reports
|
- /api/v1/reports
|
||||||
- /api/v1/statistics
|
- /api/v1/statistics
|
||||||
strip_path: false
|
strip_path: false
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Authorization Service - 授权服务
|
# Authorization Service - 授权服务
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
- name: authorization-service
|
- name: authorization-service
|
||||||
url: http://192.168.1.111:3009
|
url: http://192.168.1.111:3009
|
||||||
routes:
|
routes:
|
||||||
- name: authorization-api
|
- name: authorization-api
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/authorization
|
- /api/v1/authorization
|
||||||
- /api/v1/permissions
|
- /api/v1/permissions
|
||||||
- /api/v1/roles
|
- /api/v1/roles
|
||||||
strip_path: false
|
strip_path: false
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Admin Service - 管理服务 (包含版本管理)
|
# Admin Service - 管理服务 (包含版本管理)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
- name: admin-service
|
- name: admin-service
|
||||||
url: http://192.168.1.111:3010
|
url: http://192.168.1.111:3010
|
||||||
routes:
|
routes:
|
||||||
- name: admin-versions
|
- name: admin-versions
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/versions
|
- /api/v1/versions
|
||||||
strip_path: false
|
strip_path: false
|
||||||
- name: admin-api
|
- name: admin-api
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/admin
|
- /api/v1/admin
|
||||||
strip_path: false
|
strip_path: false
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Presence Service - 在线状态服务
|
# Presence Service - 在线状态服务
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
- name: presence-service
|
- name: presence-service
|
||||||
url: http://192.168.1.111:3011
|
url: http://192.168.1.111:3011
|
||||||
routes:
|
routes:
|
||||||
- name: presence-api
|
- name: presence-api
|
||||||
paths:
|
paths:
|
||||||
- /api/v1/presence
|
- /api/v1/presence
|
||||||
strip_path: false
|
strip_path: false
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Plugins - 全局插件配置
|
# Plugins - 全局插件配置
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
plugins:
|
plugins:
|
||||||
# CORS 跨域配置
|
# CORS 跨域配置
|
||||||
- name: cors
|
- name: cors
|
||||||
config:
|
config:
|
||||||
origins:
|
origins:
|
||||||
- "https://rwaadmin.szaiai.com"
|
- "https://rwaadmin.szaiai.com"
|
||||||
- "https://update.szaiai.com"
|
- "https://update.szaiai.com"
|
||||||
- "https://app.rwadurian.com"
|
- "https://app.rwadurian.com"
|
||||||
- "http://localhost:3000"
|
- "http://localhost:3000"
|
||||||
- "http://localhost:3020"
|
- "http://localhost:3020"
|
||||||
methods:
|
methods:
|
||||||
- GET
|
- GET
|
||||||
- POST
|
- POST
|
||||||
- PUT
|
- PUT
|
||||||
- PATCH
|
- PATCH
|
||||||
- DELETE
|
- DELETE
|
||||||
- OPTIONS
|
- OPTIONS
|
||||||
headers:
|
headers:
|
||||||
- Accept
|
- Accept
|
||||||
- Accept-Version
|
- Accept-Version
|
||||||
- Content-Length
|
- Content-Length
|
||||||
- Content-MD5
|
- Content-MD5
|
||||||
- Content-Type
|
- Content-Type
|
||||||
- Date
|
- Date
|
||||||
- Authorization
|
- Authorization
|
||||||
- X-Auth-Token
|
- X-Auth-Token
|
||||||
exposed_headers:
|
exposed_headers:
|
||||||
- X-Auth-Token
|
- X-Auth-Token
|
||||||
credentials: true
|
credentials: true
|
||||||
max_age: 3600
|
max_age: 3600
|
||||||
|
|
||||||
# 请求限流
|
# 请求限流
|
||||||
- name: rate-limiting
|
- name: rate-limiting
|
||||||
config:
|
config:
|
||||||
minute: 100
|
minute: 100
|
||||||
hour: 5000
|
hour: 5000
|
||||||
policy: local
|
policy: local
|
||||||
|
|
||||||
# 请求日志
|
# 请求日志
|
||||||
- name: file-log
|
- name: file-log
|
||||||
config:
|
config:
|
||||||
path: /tmp/kong-access.log
|
path: /tmp/kong-access.log
|
||||||
reopen: true
|
reopen: true
|
||||||
|
|
||||||
# 请求/响应大小限制 (500MB 用于 APK/IPA 上传)
|
# 请求/响应大小限制 (500MB 用于 APK/IPA 上传)
|
||||||
- name: request-size-limiting
|
- name: request-size-limiting
|
||||||
config:
|
config:
|
||||||
allowed_payload_size: 500
|
allowed_payload_size: 500
|
||||||
size_unit: megabytes
|
size_unit: megabytes
|
||||||
|
|
||||||
# Prometheus 监控指标
|
# Prometheus 监控指标
|
||||||
- name: prometheus
|
- name: prometheus
|
||||||
config:
|
config:
|
||||||
per_consumer: true
|
per_consumer: true
|
||||||
status_code_metrics: true
|
status_code_metrics: true
|
||||||
latency_metrics: true
|
latency_metrics: true
|
||||||
bandwidth_metrics: true
|
bandwidth_metrics: true
|
||||||
upstream_health_metrics: true
|
upstream_health_metrics: true
|
||||||
|
|
|
||||||
|
|
@ -1,208 +1,208 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# RWADurian API Gateway - Nginx 完整安装脚本
|
# RWADurian API Gateway - Nginx 完整安装脚本
|
||||||
# 适用于全新 Ubuntu/Debian 服务器
|
# 适用于全新 Ubuntu/Debian 服务器
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
DOMAIN="rwaapi.szaiai.com"
|
DOMAIN="rwaapi.szaiai.com"
|
||||||
EMAIL="admin@szaiai.com" # 修改为你的邮箱
|
EMAIL="admin@szaiai.com" # 修改为你的邮箱
|
||||||
KONG_PORT=8000
|
KONG_PORT=8000
|
||||||
|
|
||||||
# 颜色
|
# 颜色
|
||||||
RED='\033[0;31m'
|
RED='\033[0;31m'
|
||||||
GREEN='\033[0;32m'
|
GREEN='\033[0;32m'
|
||||||
YELLOW='\033[1;33m'
|
YELLOW='\033[1;33m'
|
||||||
BLUE='\033[0;34m'
|
BLUE='\033[0;34m'
|
||||||
NC='\033[0m'
|
NC='\033[0m'
|
||||||
|
|
||||||
log_info() { echo -e "${BLUE}[INFO]${NC} $1"; }
|
log_info() { echo -e "${BLUE}[INFO]${NC} $1"; }
|
||||||
log_success() { echo -e "${GREEN}[SUCCESS]${NC} $1"; }
|
log_success() { echo -e "${GREEN}[SUCCESS]${NC} $1"; }
|
||||||
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
||||||
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
||||||
|
|
||||||
# 检查 root 权限
|
# 检查 root 权限
|
||||||
check_root() {
|
check_root() {
|
||||||
if [ "$EUID" -ne 0 ]; then
|
if [ "$EUID" -ne 0 ]; then
|
||||||
log_error "请使用 root 权限运行: sudo ./install.sh"
|
log_error "请使用 root 权限运行: sudo ./install.sh"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# 步骤 1: 更新系统
|
# 步骤 1: 更新系统
|
||||||
update_system() {
|
update_system() {
|
||||||
log_info "步骤 1/6: 更新系统包..."
|
log_info "步骤 1/6: 更新系统包..."
|
||||||
apt update && apt upgrade -y
|
apt update && apt upgrade -y
|
||||||
log_success "系统更新完成"
|
log_success "系统更新完成"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 步骤 2: 安装 Nginx
|
# 步骤 2: 安装 Nginx
|
||||||
install_nginx() {
|
install_nginx() {
|
||||||
log_info "步骤 2/6: 安装 Nginx..."
|
log_info "步骤 2/6: 安装 Nginx..."
|
||||||
apt install -y nginx
|
apt install -y nginx
|
||||||
systemctl enable nginx
|
systemctl enable nginx
|
||||||
systemctl start nginx
|
systemctl start nginx
|
||||||
log_success "Nginx 安装完成"
|
log_success "Nginx 安装完成"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 步骤 3: 安装 Certbot
|
# 步骤 3: 安装 Certbot
|
||||||
install_certbot() {
|
install_certbot() {
|
||||||
log_info "步骤 3/6: 安装 Certbot..."
|
log_info "步骤 3/6: 安装 Certbot..."
|
||||||
apt install -y certbot python3-certbot-nginx
|
apt install -y certbot python3-certbot-nginx
|
||||||
log_success "Certbot 安装完成"
|
log_success "Certbot 安装完成"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 步骤 4: 配置 Nginx (HTTP)
|
# 步骤 4: 配置 Nginx (HTTP)
|
||||||
configure_nginx_http() {
|
configure_nginx_http() {
|
||||||
log_info "步骤 4/6: 配置 Nginx (HTTP 临时配置用于证书申请)..."
|
log_info "步骤 4/6: 配置 Nginx (HTTP 临时配置用于证书申请)..."
|
||||||
|
|
||||||
# 创建 certbot webroot 目录
|
# 创建 certbot webroot 目录
|
||||||
mkdir -p /var/www/certbot
|
mkdir -p /var/www/certbot
|
||||||
|
|
||||||
# 创建临时 HTTP 配置
|
# 创建临时 HTTP 配置
|
||||||
cat > /etc/nginx/sites-available/$DOMAIN << EOF
|
cat > /etc/nginx/sites-available/$DOMAIN << EOF
|
||||||
# 临时 HTTP 配置 - 用于 Let's Encrypt 验证
|
# 临时 HTTP 配置 - 用于 Let's Encrypt 验证
|
||||||
server {
|
server {
|
||||||
listen 80;
|
listen 80;
|
||||||
listen [::]:80;
|
listen [::]:80;
|
||||||
server_name $DOMAIN;
|
server_name $DOMAIN;
|
||||||
|
|
||||||
# Let's Encrypt 验证目录
|
# Let's Encrypt 验证目录
|
||||||
location /.well-known/acme-challenge/ {
|
location /.well-known/acme-challenge/ {
|
||||||
root /var/www/certbot;
|
root /var/www/certbot;
|
||||||
}
|
}
|
||||||
|
|
||||||
# 临时代理到 Kong
|
# 临时代理到 Kong
|
||||||
location / {
|
location / {
|
||||||
proxy_pass http://127.0.0.1:$KONG_PORT;
|
proxy_pass http://127.0.0.1:$KONG_PORT;
|
||||||
proxy_http_version 1.1;
|
proxy_http_version 1.1;
|
||||||
proxy_set_header Host \$host;
|
proxy_set_header Host \$host;
|
||||||
proxy_set_header X-Real-IP \$remote_addr;
|
proxy_set_header X-Real-IP \$remote_addr;
|
||||||
proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto \$scheme;
|
proxy_set_header X-Forwarded-Proto \$scheme;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
# 启用站点
|
# 启用站点
|
||||||
ln -sf /etc/nginx/sites-available/$DOMAIN /etc/nginx/sites-enabled/
|
ln -sf /etc/nginx/sites-available/$DOMAIN /etc/nginx/sites-enabled/
|
||||||
|
|
||||||
# 测试并重载
|
# 测试并重载
|
||||||
nginx -t && systemctl reload nginx
|
nginx -t && systemctl reload nginx
|
||||||
log_success "Nginx HTTP 配置完成"
|
log_success "Nginx HTTP 配置完成"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 步骤 5: 申请 SSL 证书
|
# 步骤 5: 申请 SSL 证书
|
||||||
obtain_ssl_certificate() {
|
obtain_ssl_certificate() {
|
||||||
log_info "步骤 5/6: 申请 Let's Encrypt SSL 证书..."
|
log_info "步骤 5/6: 申请 Let's Encrypt SSL 证书..."
|
||||||
|
|
||||||
# 检查域名解析
|
# 检查域名解析
|
||||||
log_info "检查域名 $DOMAIN 解析..."
|
log_info "检查域名 $DOMAIN 解析..."
|
||||||
if ! host $DOMAIN > /dev/null 2>&1; then
|
if ! host $DOMAIN > /dev/null 2>&1; then
|
||||||
log_warn "无法解析域名 $DOMAIN,请确保 DNS 已正确配置"
|
log_warn "无法解析域名 $DOMAIN,请确保 DNS 已正确配置"
|
||||||
log_warn "继续尝试申请证书..."
|
log_warn "继续尝试申请证书..."
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# 申请证书
|
# 申请证书
|
||||||
certbot certonly \
|
certbot certonly \
|
||||||
--webroot \
|
--webroot \
|
||||||
--webroot-path=/var/www/certbot \
|
--webroot-path=/var/www/certbot \
|
||||||
--email $EMAIL \
|
--email $EMAIL \
|
||||||
--agree-tos \
|
--agree-tos \
|
||||||
--no-eff-email \
|
--no-eff-email \
|
||||||
-d $DOMAIN
|
-d $DOMAIN
|
||||||
|
|
||||||
log_success "SSL 证书申请成功"
|
log_success "SSL 证书申请成功"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 步骤 6: 配置 Nginx (HTTPS)
|
# 步骤 6: 配置 Nginx (HTTPS)
|
||||||
configure_nginx_https() {
|
configure_nginx_https() {
|
||||||
log_info "步骤 6/6: 配置 Nginx (HTTPS)..."
|
log_info "步骤 6/6: 配置 Nginx (HTTPS)..."
|
||||||
|
|
||||||
# 获取脚本所在目录
|
# 获取脚本所在目录
|
||||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
|
||||||
# 复制完整配置
|
# 复制完整配置
|
||||||
cp "$SCRIPT_DIR/rwaapi.szaiai.com.conf" /etc/nginx/sites-available/$DOMAIN
|
cp "$SCRIPT_DIR/rwaapi.szaiai.com.conf" /etc/nginx/sites-available/$DOMAIN
|
||||||
|
|
||||||
# 测试并重载
|
# 测试并重载
|
||||||
nginx -t && systemctl reload nginx
|
nginx -t && systemctl reload nginx
|
||||||
log_success "Nginx HTTPS 配置完成"
|
log_success "Nginx HTTPS 配置完成"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 配置证书自动续期
|
# 配置证书自动续期
|
||||||
setup_auto_renewal() {
|
setup_auto_renewal() {
|
||||||
log_info "配置证书自动续期..."
|
log_info "配置证书自动续期..."
|
||||||
certbot renew --dry-run
|
certbot renew --dry-run
|
||||||
log_success "证书自动续期已配置"
|
log_success "证书自动续期已配置"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 配置防火墙
|
# 配置防火墙
|
||||||
configure_firewall() {
|
configure_firewall() {
|
||||||
log_info "配置防火墙..."
|
log_info "配置防火墙..."
|
||||||
|
|
||||||
if command -v ufw &> /dev/null; then
|
if command -v ufw &> /dev/null; then
|
||||||
ufw allow 'Nginx Full'
|
ufw allow 'Nginx Full'
|
||||||
ufw allow OpenSSH
|
ufw allow OpenSSH
|
||||||
ufw --force enable
|
ufw --force enable
|
||||||
log_success "UFW 防火墙已配置"
|
log_success "UFW 防火墙已配置"
|
||||||
else
|
else
|
||||||
log_warn "未检测到 UFW,请手动配置防火墙开放 80 和 443 端口"
|
log_warn "未检测到 UFW,请手动配置防火墙开放 80 和 443 端口"
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# 显示完成信息
|
# 显示完成信息
|
||||||
show_completion() {
|
show_completion() {
|
||||||
echo ""
|
echo ""
|
||||||
echo -e "${GREEN}========================================${NC}"
|
echo -e "${GREEN}========================================${NC}"
|
||||||
echo -e "${GREEN} 安装完成!${NC}"
|
echo -e "${GREEN} 安装完成!${NC}"
|
||||||
echo -e "${GREEN}========================================${NC}"
|
echo -e "${GREEN}========================================${NC}"
|
||||||
echo ""
|
echo ""
|
||||||
echo -e "API 网关地址: ${BLUE}https://$DOMAIN${NC}"
|
echo -e "API 网关地址: ${BLUE}https://$DOMAIN${NC}"
|
||||||
echo ""
|
echo ""
|
||||||
echo "架构:"
|
echo "架构:"
|
||||||
echo " 用户请求 → Nginx (SSL) → Kong (API Gateway) → 微服务"
|
echo " 用户请求 → Nginx (SSL) → Kong (API Gateway) → 微服务"
|
||||||
echo ""
|
echo ""
|
||||||
echo "常用命令:"
|
echo "常用命令:"
|
||||||
echo " 查看 Nginx 状态: systemctl status nginx"
|
echo " 查看 Nginx 状态: systemctl status nginx"
|
||||||
echo " 重载 Nginx: systemctl reload nginx"
|
echo " 重载 Nginx: systemctl reload nginx"
|
||||||
echo " 查看证书: certbot certificates"
|
echo " 查看证书: certbot certificates"
|
||||||
echo " 手动续期: certbot renew"
|
echo " 手动续期: certbot renew"
|
||||||
echo " 查看日志: tail -f /var/log/nginx/$DOMAIN.access.log"
|
echo " 查看日志: tail -f /var/log/nginx/$DOMAIN.access.log"
|
||||||
echo ""
|
echo ""
|
||||||
}
|
}
|
||||||
|
|
||||||
# 主函数
|
# 主函数
|
||||||
main() {
|
main() {
|
||||||
echo ""
|
echo ""
|
||||||
echo "============================================"
|
echo "============================================"
|
||||||
echo " RWADurian API Gateway - Nginx 安装脚本"
|
echo " RWADurian API Gateway - Nginx 安装脚本"
|
||||||
echo " 域名: $DOMAIN"
|
echo " 域名: $DOMAIN"
|
||||||
echo "============================================"
|
echo "============================================"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
check_root
|
check_root
|
||||||
update_system
|
update_system
|
||||||
install_nginx
|
install_nginx
|
||||||
install_certbot
|
install_certbot
|
||||||
configure_firewall
|
configure_firewall
|
||||||
configure_nginx_http
|
configure_nginx_http
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
log_warn "请确保以下条件已满足:"
|
log_warn "请确保以下条件已满足:"
|
||||||
echo " 1. 域名 $DOMAIN 的 DNS A 记录已指向本服务器 IP"
|
echo " 1. 域名 $DOMAIN 的 DNS A 记录已指向本服务器 IP"
|
||||||
echo " 2. Kong API Gateway 已在端口 $KONG_PORT 运行"
|
echo " 2. Kong API Gateway 已在端口 $KONG_PORT 运行"
|
||||||
echo ""
|
echo ""
|
||||||
read -p "是否继续申请 SSL 证书? (y/n): " confirm
|
read -p "是否继续申请 SSL 证书? (y/n): " confirm
|
||||||
|
|
||||||
if [ "$confirm" = "y" ] || [ "$confirm" = "Y" ]; then
|
if [ "$confirm" = "y" ] || [ "$confirm" = "Y" ]; then
|
||||||
obtain_ssl_certificate
|
obtain_ssl_certificate
|
||||||
configure_nginx_https
|
configure_nginx_https
|
||||||
setup_auto_renewal
|
setup_auto_renewal
|
||||||
show_completion
|
show_completion
|
||||||
else
|
else
|
||||||
log_info "已跳过 SSL 配置,当前为 HTTP 模式"
|
log_info "已跳过 SSL 配置,当前为 HTTP 模式"
|
||||||
log_info "稍后可运行: certbot --nginx -d $DOMAIN"
|
log_info "稍后可运行: certbot --nginx -d $DOMAIN"
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
main "$@"
|
main "$@"
|
||||||
|
|
|
||||||
|
|
@ -1,112 +1,112 @@
|
||||||
# RWADurian API Gateway Nginx 配置
|
# RWADurian API Gateway Nginx 配置
|
||||||
# 域名: rwaapi.szaiai.com
|
# 域名: rwaapi.szaiai.com
|
||||||
# 后端: Kong API Gateway (端口 8000)
|
# 后端: Kong API Gateway (端口 8000)
|
||||||
# 放置路径: /etc/nginx/sites-available/rwaapi.szaiai.com
|
# 放置路径: /etc/nginx/sites-available/rwaapi.szaiai.com
|
||||||
# 启用: ln -s /etc/nginx/sites-available/rwaapi.szaiai.com /etc/nginx/sites-enabled/
|
# 启用: ln -s /etc/nginx/sites-available/rwaapi.szaiai.com /etc/nginx/sites-enabled/
|
||||||
|
|
||||||
# HTTP 重定向到 HTTPS
|
# HTTP 重定向到 HTTPS
|
||||||
server {
|
server {
|
||||||
listen 80;
|
listen 80;
|
||||||
listen [::]:80;
|
listen [::]:80;
|
||||||
server_name rwaapi.szaiai.com;
|
server_name rwaapi.szaiai.com;
|
||||||
|
|
||||||
# Let's Encrypt 验证目录
|
# Let's Encrypt 验证目录
|
||||||
location /.well-known/acme-challenge/ {
|
location /.well-known/acme-challenge/ {
|
||||||
root /var/www/certbot;
|
root /var/www/certbot;
|
||||||
}
|
}
|
||||||
|
|
||||||
# 重定向到 HTTPS
|
# 重定向到 HTTPS
|
||||||
location / {
|
location / {
|
||||||
return 301 https://$host$request_uri;
|
return 301 https://$host$request_uri;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# HTTPS 配置
|
# HTTPS 配置
|
||||||
server {
|
server {
|
||||||
listen 443 ssl http2;
|
listen 443 ssl http2;
|
||||||
listen [::]:443 ssl http2;
|
listen [::]:443 ssl http2;
|
||||||
server_name rwaapi.szaiai.com;
|
server_name rwaapi.szaiai.com;
|
||||||
|
|
||||||
# SSL 证书 (Let's Encrypt)
|
# SSL 证书 (Let's Encrypt)
|
||||||
ssl_certificate /etc/letsencrypt/live/rwaapi.szaiai.com/fullchain.pem;
|
ssl_certificate /etc/letsencrypt/live/rwaapi.szaiai.com/fullchain.pem;
|
||||||
ssl_certificate_key /etc/letsencrypt/live/rwaapi.szaiai.com/privkey.pem;
|
ssl_certificate_key /etc/letsencrypt/live/rwaapi.szaiai.com/privkey.pem;
|
||||||
|
|
||||||
# SSL 配置优化
|
# SSL 配置优化
|
||||||
ssl_session_timeout 1d;
|
ssl_session_timeout 1d;
|
||||||
ssl_session_cache shared:SSL:50m;
|
ssl_session_cache shared:SSL:50m;
|
||||||
ssl_session_tickets off;
|
ssl_session_tickets off;
|
||||||
|
|
||||||
# 现代加密套件
|
# 现代加密套件
|
||||||
ssl_protocols TLSv1.2 TLSv1.3;
|
ssl_protocols TLSv1.2 TLSv1.3;
|
||||||
ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384;
|
ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384;
|
||||||
ssl_prefer_server_ciphers off;
|
ssl_prefer_server_ciphers off;
|
||||||
|
|
||||||
# HSTS
|
# HSTS
|
||||||
add_header Strict-Transport-Security "max-age=63072000" always;
|
add_header Strict-Transport-Security "max-age=63072000" always;
|
||||||
|
|
||||||
# 日志
|
# 日志
|
||||||
access_log /var/log/nginx/rwaapi.szaiai.com.access.log;
|
access_log /var/log/nginx/rwaapi.szaiai.com.access.log;
|
||||||
error_log /var/log/nginx/rwaapi.szaiai.com.error.log;
|
error_log /var/log/nginx/rwaapi.szaiai.com.error.log;
|
||||||
|
|
||||||
# Gzip 压缩
|
# Gzip 压缩
|
||||||
gzip on;
|
gzip on;
|
||||||
gzip_vary on;
|
gzip_vary on;
|
||||||
gzip_proxied any;
|
gzip_proxied any;
|
||||||
gzip_comp_level 6;
|
gzip_comp_level 6;
|
||||||
gzip_types text/plain text/css text/xml application/json application/javascript application/rss+xml application/atom+xml image/svg+xml;
|
gzip_types text/plain text/css text/xml application/json application/javascript application/rss+xml application/atom+xml image/svg+xml;
|
||||||
|
|
||||||
# 安全头
|
# 安全头
|
||||||
add_header X-Frame-Options "SAMEORIGIN" always;
|
add_header X-Frame-Options "SAMEORIGIN" always;
|
||||||
add_header X-Content-Type-Options "nosniff" always;
|
add_header X-Content-Type-Options "nosniff" always;
|
||||||
add_header X-XSS-Protection "1; mode=block" always;
|
add_header X-XSS-Protection "1; mode=block" always;
|
||||||
add_header Referrer-Policy "strict-origin-when-cross-origin" always;
|
add_header Referrer-Policy "strict-origin-when-cross-origin" always;
|
||||||
|
|
||||||
# 客户端请求大小限制 (500MB 用于 APK/IPA 上传)
|
# 客户端请求大小限制 (500MB 用于 APK/IPA 上传)
|
||||||
client_max_body_size 500M;
|
client_max_body_size 500M;
|
||||||
|
|
||||||
# 反向代理到 Kong API Gateway
|
# 反向代理到 Kong API Gateway
|
||||||
location / {
|
location / {
|
||||||
proxy_pass http://127.0.0.1:8000;
|
proxy_pass http://127.0.0.1:8000;
|
||||||
proxy_http_version 1.1;
|
proxy_http_version 1.1;
|
||||||
proxy_set_header Upgrade $http_upgrade;
|
proxy_set_header Upgrade $http_upgrade;
|
||||||
proxy_set_header Connection 'upgrade';
|
proxy_set_header Connection 'upgrade';
|
||||||
proxy_set_header Host $host;
|
proxy_set_header Host $host;
|
||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
proxy_set_header X-Forwarded-Host $host;
|
proxy_set_header X-Forwarded-Host $host;
|
||||||
proxy_set_header X-Forwarded-Port $server_port;
|
proxy_set_header X-Forwarded-Port $server_port;
|
||||||
proxy_cache_bypass $http_upgrade;
|
proxy_cache_bypass $http_upgrade;
|
||||||
|
|
||||||
# 超时设置 (适配大文件上传)
|
# 超时设置 (适配大文件上传)
|
||||||
proxy_connect_timeout 60s;
|
proxy_connect_timeout 60s;
|
||||||
proxy_send_timeout 300s;
|
proxy_send_timeout 300s;
|
||||||
proxy_read_timeout 300s;
|
proxy_read_timeout 300s;
|
||||||
|
|
||||||
# 缓冲设置
|
# 缓冲设置
|
||||||
proxy_buffering on;
|
proxy_buffering on;
|
||||||
proxy_buffer_size 128k;
|
proxy_buffer_size 128k;
|
||||||
proxy_buffers 4 256k;
|
proxy_buffers 4 256k;
|
||||||
proxy_busy_buffers_size 256k;
|
proxy_busy_buffers_size 256k;
|
||||||
}
|
}
|
||||||
|
|
||||||
# Kong Admin API (可选,仅内网访问)
|
# Kong Admin API (可选,仅内网访问)
|
||||||
# location /kong-admin/ {
|
# location /kong-admin/ {
|
||||||
# allow 127.0.0.1;
|
# allow 127.0.0.1;
|
||||||
# allow 10.0.0.0/8;
|
# allow 10.0.0.0/8;
|
||||||
# allow 172.16.0.0/12;
|
# allow 172.16.0.0/12;
|
||||||
# allow 192.168.0.0/16;
|
# allow 192.168.0.0/16;
|
||||||
# deny all;
|
# deny all;
|
||||||
# proxy_pass http://127.0.0.1:8001/;
|
# proxy_pass http://127.0.0.1:8001/;
|
||||||
# proxy_http_version 1.1;
|
# proxy_http_version 1.1;
|
||||||
# proxy_set_header Host $host;
|
# proxy_set_header Host $host;
|
||||||
# proxy_set_header X-Real-IP $remote_addr;
|
# proxy_set_header X-Real-IP $remote_addr;
|
||||||
# }
|
# }
|
||||||
|
|
||||||
# 健康检查端点 (直接返回)
|
# 健康检查端点 (直接返回)
|
||||||
location = /health {
|
location = /health {
|
||||||
access_log off;
|
access_log off;
|
||||||
return 200 '{"status":"ok","service":"rwaapi-nginx"}';
|
return 200 '{"status":"ok","service":"rwaapi-nginx"}';
|
||||||
add_header Content-Type application/json;
|
add_header Content-Type application/json;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,37 +1,37 @@
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Prometheus 配置 - Kong API Gateway + RWA Services 监控
|
# Prometheus 配置 - Kong API Gateway + RWA Services 监控
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
global:
|
global:
|
||||||
scrape_interval: 15s
|
scrape_interval: 15s
|
||||||
evaluation_interval: 15s
|
evaluation_interval: 15s
|
||||||
|
|
||||||
scrape_configs:
|
scrape_configs:
|
||||||
# Kong Prometheus 指标端点
|
# Kong Prometheus 指标端点
|
||||||
- job_name: 'kong'
|
- job_name: 'kong'
|
||||||
static_configs:
|
static_configs:
|
||||||
- targets: ['kong:8001']
|
- targets: ['kong:8001']
|
||||||
metrics_path: /metrics
|
metrics_path: /metrics
|
||||||
|
|
||||||
# Prometheus 自身监控
|
# Prometheus 自身监控
|
||||||
- job_name: 'prometheus'
|
- job_name: 'prometheus'
|
||||||
static_configs:
|
static_configs:
|
||||||
- targets: ['localhost:9090']
|
- targets: ['localhost:9090']
|
||||||
|
|
||||||
# ==========================================================================
|
# ==========================================================================
|
||||||
# RWA Presence Service - 用户活跃度与在线状态监控
|
# RWA Presence Service - 用户活跃度与在线状态监控
|
||||||
# ==========================================================================
|
# ==========================================================================
|
||||||
- job_name: 'presence-service'
|
- job_name: 'presence-service'
|
||||||
static_configs:
|
static_configs:
|
||||||
# 生产环境: 使用内网 IP 或 Docker 网络名称
|
# 生产环境: 使用内网 IP 或 Docker 网络名称
|
||||||
# - targets: ['presence-service:3011']
|
# - targets: ['presence-service:3011']
|
||||||
# 开发环境: 使用 host.docker.internal 访问宿主机服务
|
# 开发环境: 使用 host.docker.internal 访问宿主机服务
|
||||||
- targets: ['host.docker.internal:3011']
|
- targets: ['host.docker.internal:3011']
|
||||||
metrics_path: /api/v1/metrics
|
metrics_path: /api/v1/metrics
|
||||||
scrape_interval: 15s
|
scrape_interval: 15s
|
||||||
scrape_timeout: 10s
|
scrape_timeout: 10s
|
||||||
# 添加标签便于区分
|
# 添加标签便于区分
|
||||||
relabel_configs:
|
relabel_configs:
|
||||||
- source_labels: [__address__]
|
- source_labels: [__address__]
|
||||||
target_label: instance
|
target_label: instance
|
||||||
replacement: 'presence-service'
|
replacement: 'presence-service'
|
||||||
|
|
|
||||||
|
|
@ -1,380 +1,380 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Kong 监控栈一键安装脚本
|
# Kong 监控栈一键安装脚本
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 功能:
|
# 功能:
|
||||||
# - 自动配置 Nginx 反向代理
|
# - 自动配置 Nginx 反向代理
|
||||||
# - 自动申请 Let's Encrypt SSL 证书
|
# - 自动申请 Let's Encrypt SSL 证书
|
||||||
# - 启动 Prometheus + Grafana 监控服务
|
# - 启动 Prometheus + Grafana 监控服务
|
||||||
#
|
#
|
||||||
# 用法:
|
# 用法:
|
||||||
# ./install-monitor.sh # 使用默认域名 monitor.szaiai.com
|
# ./install-monitor.sh # 使用默认域名 monitor.szaiai.com
|
||||||
# ./install-monitor.sh mydomain.com # 使用自定义域名
|
# ./install-monitor.sh mydomain.com # 使用自定义域名
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
# 颜色定义
|
# 颜色定义
|
||||||
RED='\033[0;31m'
|
RED='\033[0;31m'
|
||||||
GREEN='\033[0;32m'
|
GREEN='\033[0;32m'
|
||||||
YELLOW='\033[1;33m'
|
YELLOW='\033[1;33m'
|
||||||
BLUE='\033[0;34m'
|
BLUE='\033[0;34m'
|
||||||
CYAN='\033[0;36m'
|
CYAN='\033[0;36m'
|
||||||
NC='\033[0m'
|
NC='\033[0m'
|
||||||
|
|
||||||
# 日志函数
|
# 日志函数
|
||||||
log_info() { echo -e "${BLUE}[INFO]${NC} $1"; }
|
log_info() { echo -e "${BLUE}[INFO]${NC} $1"; }
|
||||||
log_success() { echo -e "${GREEN}[OK]${NC} $1"; }
|
log_success() { echo -e "${GREEN}[OK]${NC} $1"; }
|
||||||
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
||||||
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
||||||
log_step() { echo -e "${CYAN}[STEP]${NC} $1"; }
|
log_step() { echo -e "${CYAN}[STEP]${NC} $1"; }
|
||||||
|
|
||||||
# 默认配置
|
# 默认配置
|
||||||
DOMAIN="${1:-monitor.szaiai.com}"
|
DOMAIN="${1:-monitor.szaiai.com}"
|
||||||
GRAFANA_PORT=3030
|
GRAFANA_PORT=3030
|
||||||
PROMETHEUS_PORT=9099
|
PROMETHEUS_PORT=9099
|
||||||
GRAFANA_USER="admin"
|
GRAFANA_USER="admin"
|
||||||
GRAFANA_PASS="admin123"
|
GRAFANA_PASS="admin123"
|
||||||
|
|
||||||
# 获取脚本目录
|
# 获取脚本目录
|
||||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
|
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||||
|
|
||||||
# 显示 Banner
|
# 显示 Banner
|
||||||
show_banner() {
|
show_banner() {
|
||||||
echo -e "${CYAN}"
|
echo -e "${CYAN}"
|
||||||
echo "╔═══════════════════════════════════════════════════════════════╗"
|
echo "╔═══════════════════════════════════════════════════════════════╗"
|
||||||
echo "║ Kong 监控栈一键安装脚本 ║"
|
echo "║ Kong 监控栈一键安装脚本 ║"
|
||||||
echo "║ Prometheus + Grafana ║"
|
echo "║ Prometheus + Grafana ║"
|
||||||
echo "╚═══════════════════════════════════════════════════════════════╝"
|
echo "╚═══════════════════════════════════════════════════════════════╝"
|
||||||
echo -e "${NC}"
|
echo -e "${NC}"
|
||||||
echo "域名: $DOMAIN"
|
echo "域名: $DOMAIN"
|
||||||
echo "Grafana 端口: $GRAFANA_PORT"
|
echo "Grafana 端口: $GRAFANA_PORT"
|
||||||
echo "Prometheus 端口: $PROMETHEUS_PORT"
|
echo "Prometheus 端口: $PROMETHEUS_PORT"
|
||||||
echo ""
|
echo ""
|
||||||
}
|
}
|
||||||
|
|
||||||
# 检查 root 权限
|
# 检查 root 权限
|
||||||
check_root() {
|
check_root() {
|
||||||
if [ "$EUID" -ne 0 ]; then
|
if [ "$EUID" -ne 0 ]; then
|
||||||
log_error "请使用 root 权限运行此脚本"
|
log_error "请使用 root 权限运行此脚本"
|
||||||
echo "用法: sudo $0 [domain]"
|
echo "用法: sudo $0 [domain]"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# 检查依赖
|
# 检查依赖
|
||||||
check_dependencies() {
|
check_dependencies() {
|
||||||
log_step "检查依赖..."
|
log_step "检查依赖..."
|
||||||
|
|
||||||
local missing=()
|
local missing=()
|
||||||
|
|
||||||
if ! command -v docker &> /dev/null; then
|
if ! command -v docker &> /dev/null; then
|
||||||
missing+=("docker")
|
missing+=("docker")
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if ! command -v nginx &> /dev/null; then
|
if ! command -v nginx &> /dev/null; then
|
||||||
missing+=("nginx")
|
missing+=("nginx")
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if ! command -v certbot &> /dev/null; then
|
if ! command -v certbot &> /dev/null; then
|
||||||
missing+=("certbot")
|
missing+=("certbot")
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ${#missing[@]} -gt 0 ]; then
|
if [ ${#missing[@]} -gt 0 ]; then
|
||||||
log_error "缺少依赖: ${missing[*]}"
|
log_error "缺少依赖: ${missing[*]}"
|
||||||
echo ""
|
echo ""
|
||||||
echo "请先安装:"
|
echo "请先安装:"
|
||||||
echo " apt update && apt install -y docker.io nginx certbot python3-certbot-nginx"
|
echo " apt update && apt install -y docker.io nginx certbot python3-certbot-nginx"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
log_success "依赖检查通过"
|
log_success "依赖检查通过"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 检查 DNS 解析
|
# 检查 DNS 解析
|
||||||
check_dns() {
|
check_dns() {
|
||||||
log_step "检查 DNS 解析..."
|
log_step "检查 DNS 解析..."
|
||||||
|
|
||||||
local resolved_ip=$(dig +short $DOMAIN 2>/dev/null | head -1)
|
local resolved_ip=$(dig +short $DOMAIN 2>/dev/null | head -1)
|
||||||
local server_ip=$(curl -s ifconfig.me 2>/dev/null || curl -s icanhazip.com 2>/dev/null)
|
local server_ip=$(curl -s ifconfig.me 2>/dev/null || curl -s icanhazip.com 2>/dev/null)
|
||||||
|
|
||||||
if [ -z "$resolved_ip" ]; then
|
if [ -z "$resolved_ip" ]; then
|
||||||
log_error "无法解析域名 $DOMAIN"
|
log_error "无法解析域名 $DOMAIN"
|
||||||
echo "请先在 DNS 管理面板添加 A 记录:"
|
echo "请先在 DNS 管理面板添加 A 记录:"
|
||||||
echo " $DOMAIN -> $server_ip"
|
echo " $DOMAIN -> $server_ip"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "$resolved_ip" != "$server_ip" ]; then
|
if [ "$resolved_ip" != "$server_ip" ]; then
|
||||||
log_warn "DNS 解析的 IP ($resolved_ip) 与本机公网 IP ($server_ip) 不匹配"
|
log_warn "DNS 解析的 IP ($resolved_ip) 与本机公网 IP ($server_ip) 不匹配"
|
||||||
read -p "是否继续? [y/N] " -n 1 -r
|
read -p "是否继续? [y/N] " -n 1 -r
|
||||||
echo
|
echo
|
||||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
log_success "DNS 解析正确: $DOMAIN -> $resolved_ip"
|
log_success "DNS 解析正确: $DOMAIN -> $resolved_ip"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 生成 Nginx 配置
|
# 生成 Nginx 配置
|
||||||
generate_nginx_config() {
|
generate_nginx_config() {
|
||||||
log_step "生成 Nginx 配置..."
|
log_step "生成 Nginx 配置..."
|
||||||
|
|
||||||
cat > /etc/nginx/sites-available/$DOMAIN.conf << EOF
|
cat > /etc/nginx/sites-available/$DOMAIN.conf << EOF
|
||||||
# Kong 监控面板 Nginx 配置
|
# Kong 监控面板 Nginx 配置
|
||||||
# 自动生成于 $(date)
|
# 自动生成于 $(date)
|
||||||
|
|
||||||
# HTTP -> HTTPS 重定向
|
# HTTP -> HTTPS 重定向
|
||||||
server {
|
server {
|
||||||
listen 80;
|
listen 80;
|
||||||
listen [::]:80;
|
listen [::]:80;
|
||||||
server_name $DOMAIN;
|
server_name $DOMAIN;
|
||||||
|
|
||||||
location /.well-known/acme-challenge/ {
|
location /.well-known/acme-challenge/ {
|
||||||
root /var/www/certbot;
|
root /var/www/certbot;
|
||||||
}
|
}
|
||||||
|
|
||||||
location / {
|
location / {
|
||||||
return 301 https://\$host\$request_uri;
|
return 301 https://\$host\$request_uri;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# HTTPS 配置
|
# HTTPS 配置
|
||||||
server {
|
server {
|
||||||
listen 443 ssl http2;
|
listen 443 ssl http2;
|
||||||
listen [::]:443 ssl http2;
|
listen [::]:443 ssl http2;
|
||||||
server_name $DOMAIN;
|
server_name $DOMAIN;
|
||||||
|
|
||||||
# SSL 证书 (Let's Encrypt)
|
# SSL 证书 (Let's Encrypt)
|
||||||
ssl_certificate /etc/letsencrypt/live/$DOMAIN/fullchain.pem;
|
ssl_certificate /etc/letsencrypt/live/$DOMAIN/fullchain.pem;
|
||||||
ssl_certificate_key /etc/letsencrypt/live/$DOMAIN/privkey.pem;
|
ssl_certificate_key /etc/letsencrypt/live/$DOMAIN/privkey.pem;
|
||||||
|
|
||||||
# SSL 优化
|
# SSL 优化
|
||||||
ssl_session_timeout 1d;
|
ssl_session_timeout 1d;
|
||||||
ssl_session_cache shared:SSL:50m;
|
ssl_session_cache shared:SSL:50m;
|
||||||
ssl_session_tickets off;
|
ssl_session_tickets off;
|
||||||
ssl_protocols TLSv1.2 TLSv1.3;
|
ssl_protocols TLSv1.2 TLSv1.3;
|
||||||
ssl_prefer_server_ciphers off;
|
ssl_prefer_server_ciphers off;
|
||||||
|
|
||||||
# HSTS
|
# HSTS
|
||||||
add_header Strict-Transport-Security "max-age=63072000" always;
|
add_header Strict-Transport-Security "max-age=63072000" always;
|
||||||
|
|
||||||
# 日志
|
# 日志
|
||||||
access_log /var/log/nginx/$DOMAIN.access.log;
|
access_log /var/log/nginx/$DOMAIN.access.log;
|
||||||
error_log /var/log/nginx/$DOMAIN.error.log;
|
error_log /var/log/nginx/$DOMAIN.error.log;
|
||||||
|
|
||||||
# Grafana
|
# Grafana
|
||||||
location / {
|
location / {
|
||||||
proxy_pass http://127.0.0.1:$GRAFANA_PORT;
|
proxy_pass http://127.0.0.1:$GRAFANA_PORT;
|
||||||
proxy_http_version 1.1;
|
proxy_http_version 1.1;
|
||||||
|
|
||||||
# WebSocket support
|
# WebSocket support
|
||||||
proxy_set_header Upgrade \$http_upgrade;
|
proxy_set_header Upgrade \$http_upgrade;
|
||||||
proxy_set_header Connection 'upgrade';
|
proxy_set_header Connection 'upgrade';
|
||||||
|
|
||||||
# Standard proxy headers
|
# Standard proxy headers
|
||||||
proxy_set_header Host \$http_host;
|
proxy_set_header Host \$http_host;
|
||||||
proxy_set_header X-Real-IP \$remote_addr;
|
proxy_set_header X-Real-IP \$remote_addr;
|
||||||
proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto \$scheme;
|
proxy_set_header X-Forwarded-Proto \$scheme;
|
||||||
proxy_set_header X-Forwarded-Host \$host;
|
proxy_set_header X-Forwarded-Host \$host;
|
||||||
proxy_set_header X-Forwarded-Port \$server_port;
|
proxy_set_header X-Forwarded-Port \$server_port;
|
||||||
|
|
||||||
# Grafana 10+ 反向代理支持
|
# Grafana 10+ 反向代理支持
|
||||||
proxy_set_header Origin \$scheme://\$host;
|
proxy_set_header Origin \$scheme://\$host;
|
||||||
|
|
||||||
# 缓存和超时
|
# 缓存和超时
|
||||||
proxy_cache_bypass \$http_upgrade;
|
proxy_cache_bypass \$http_upgrade;
|
||||||
proxy_read_timeout 86400;
|
proxy_read_timeout 86400;
|
||||||
proxy_buffering off;
|
proxy_buffering off;
|
||||||
}
|
}
|
||||||
|
|
||||||
# Prometheus (仅内网)
|
# Prometheus (仅内网)
|
||||||
location /prometheus/ {
|
location /prometheus/ {
|
||||||
allow 127.0.0.1;
|
allow 127.0.0.1;
|
||||||
allow 10.0.0.0/8;
|
allow 10.0.0.0/8;
|
||||||
allow 172.16.0.0/12;
|
allow 172.16.0.0/12;
|
||||||
allow 192.168.0.0/16;
|
allow 192.168.0.0/16;
|
||||||
deny all;
|
deny all;
|
||||||
|
|
||||||
proxy_pass http://127.0.0.1:$PROMETHEUS_PORT/;
|
proxy_pass http://127.0.0.1:$PROMETHEUS_PORT/;
|
||||||
proxy_http_version 1.1;
|
proxy_http_version 1.1;
|
||||||
proxy_set_header Host \$host;
|
proxy_set_header Host \$host;
|
||||||
proxy_set_header X-Real-IP \$remote_addr;
|
proxy_set_header X-Real-IP \$remote_addr;
|
||||||
}
|
}
|
||||||
|
|
||||||
# 健康检查
|
# 健康检查
|
||||||
location = /health {
|
location = /health {
|
||||||
access_log off;
|
access_log off;
|
||||||
return 200 '{"status":"ok","service":"monitor-nginx"}';
|
return 200 '{"status":"ok","service":"monitor-nginx"}';
|
||||||
add_header Content-Type application/json;
|
add_header Content-Type application/json;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
log_success "Nginx 配置已生成: /etc/nginx/sites-available/$DOMAIN.conf"
|
log_success "Nginx 配置已生成: /etc/nginx/sites-available/$DOMAIN.conf"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 申请 SSL 证书
|
# 申请 SSL 证书
|
||||||
obtain_ssl_cert() {
|
obtain_ssl_cert() {
|
||||||
log_step "申请 SSL 证书..."
|
log_step "申请 SSL 证书..."
|
||||||
|
|
||||||
# 检查证书是否已存在
|
# 检查证书是否已存在
|
||||||
if [ -f "/etc/letsencrypt/live/$DOMAIN/fullchain.pem" ]; then
|
if [ -f "/etc/letsencrypt/live/$DOMAIN/fullchain.pem" ]; then
|
||||||
log_success "SSL 证书已存在"
|
log_success "SSL 证书已存在"
|
||||||
return 0
|
return 0
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# 创建 certbot webroot 目录
|
# 创建 certbot webroot 目录
|
||||||
mkdir -p /var/www/certbot
|
mkdir -p /var/www/certbot
|
||||||
|
|
||||||
# 临时启用 HTTP 配置用于验证
|
# 临时启用 HTTP 配置用于验证
|
||||||
cat > /etc/nginx/sites-available/$DOMAIN-temp.conf << EOF
|
cat > /etc/nginx/sites-available/$DOMAIN-temp.conf << EOF
|
||||||
server {
|
server {
|
||||||
listen 80;
|
listen 80;
|
||||||
server_name $DOMAIN;
|
server_name $DOMAIN;
|
||||||
|
|
||||||
location /.well-known/acme-challenge/ {
|
location /.well-known/acme-challenge/ {
|
||||||
root /var/www/certbot;
|
root /var/www/certbot;
|
||||||
}
|
}
|
||||||
|
|
||||||
location / {
|
location / {
|
||||||
return 200 'Waiting for SSL...';
|
return 200 'Waiting for SSL...';
|
||||||
add_header Content-Type text/plain;
|
add_header Content-Type text/plain;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
ln -sf /etc/nginx/sites-available/$DOMAIN-temp.conf /etc/nginx/sites-enabled/
|
ln -sf /etc/nginx/sites-available/$DOMAIN-temp.conf /etc/nginx/sites-enabled/
|
||||||
nginx -t && systemctl reload nginx
|
nginx -t && systemctl reload nginx
|
||||||
|
|
||||||
# 申请证书
|
# 申请证书
|
||||||
certbot certonly --webroot -w /var/www/certbot -d $DOMAIN --non-interactive --agree-tos --email admin@$DOMAIN || {
|
certbot certonly --webroot -w /var/www/certbot -d $DOMAIN --non-interactive --agree-tos --email admin@$DOMAIN || {
|
||||||
log_error "SSL 证书申请失败"
|
log_error "SSL 证书申请失败"
|
||||||
rm -f /etc/nginx/sites-enabled/$DOMAIN-temp.conf
|
rm -f /etc/nginx/sites-enabled/$DOMAIN-temp.conf
|
||||||
rm -f /etc/nginx/sites-available/$DOMAIN-temp.conf
|
rm -f /etc/nginx/sites-available/$DOMAIN-temp.conf
|
||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
|
|
||||||
# 清理临时配置
|
# 清理临时配置
|
||||||
rm -f /etc/nginx/sites-enabled/$DOMAIN-temp.conf
|
rm -f /etc/nginx/sites-enabled/$DOMAIN-temp.conf
|
||||||
rm -f /etc/nginx/sites-available/$DOMAIN-temp.conf
|
rm -f /etc/nginx/sites-available/$DOMAIN-temp.conf
|
||||||
|
|
||||||
log_success "SSL 证书申请成功"
|
log_success "SSL 证书申请成功"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 启用 Nginx 配置
|
# 启用 Nginx 配置
|
||||||
enable_nginx_config() {
|
enable_nginx_config() {
|
||||||
log_step "启用 Nginx 配置..."
|
log_step "启用 Nginx 配置..."
|
||||||
|
|
||||||
ln -sf /etc/nginx/sites-available/$DOMAIN.conf /etc/nginx/sites-enabled/
|
ln -sf /etc/nginx/sites-available/$DOMAIN.conf /etc/nginx/sites-enabled/
|
||||||
|
|
||||||
nginx -t || {
|
nginx -t || {
|
||||||
log_error "Nginx 配置测试失败"
|
log_error "Nginx 配置测试失败"
|
||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
|
|
||||||
systemctl reload nginx
|
systemctl reload nginx
|
||||||
log_success "Nginx 配置已启用"
|
log_success "Nginx 配置已启用"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 启动监控服务
|
# 启动监控服务
|
||||||
start_monitoring_services() {
|
start_monitoring_services() {
|
||||||
log_step "启动监控服务..."
|
log_step "启动监控服务..."
|
||||||
|
|
||||||
cd "$PROJECT_DIR"
|
cd "$PROJECT_DIR"
|
||||||
|
|
||||||
# 检查 Kong 是否运行
|
# 检查 Kong 是否运行
|
||||||
if ! docker ps | grep -q rwa-kong; then
|
if ! docker ps | grep -q rwa-kong; then
|
||||||
log_warn "Kong 未运行,先启动 Kong..."
|
log_warn "Kong 未运行,先启动 Kong..."
|
||||||
docker compose up -d
|
docker compose up -d
|
||||||
sleep 10
|
sleep 10
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# 同步 Kong 配置 (启用 prometheus 插件)
|
# 同步 Kong 配置 (启用 prometheus 插件)
|
||||||
log_info "同步 Kong 配置..."
|
log_info "同步 Kong 配置..."
|
||||||
docker compose run --rm kong-config || log_warn "配置同步失败,可能已是最新"
|
docker compose run --rm kong-config || log_warn "配置同步失败,可能已是最新"
|
||||||
|
|
||||||
# 启动监控栈
|
# 启动监控栈
|
||||||
log_info "启动 Prometheus + Grafana..."
|
log_info "启动 Prometheus + Grafana..."
|
||||||
docker compose -f docker-compose.yml -f docker-compose.monitoring.yml up -d prometheus grafana
|
docker compose -f docker-compose.yml -f docker-compose.monitoring.yml up -d prometheus grafana
|
||||||
|
|
||||||
# 等待服务启动
|
# 等待服务启动
|
||||||
sleep 5
|
sleep 5
|
||||||
|
|
||||||
# 检查服务状态
|
# 检查服务状态
|
||||||
if docker ps | grep -q rwa-grafana && docker ps | grep -q rwa-prometheus; then
|
if docker ps | grep -q rwa-grafana && docker ps | grep -q rwa-prometheus; then
|
||||||
log_success "监控服务启动成功"
|
log_success "监控服务启动成功"
|
||||||
else
|
else
|
||||||
log_error "监控服务启动失败"
|
log_error "监控服务启动失败"
|
||||||
docker compose -f docker-compose.yml -f docker-compose.monitoring.yml logs --tail=50
|
docker compose -f docker-compose.yml -f docker-compose.monitoring.yml logs --tail=50
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# 显示安装结果
|
# 显示安装结果
|
||||||
show_result() {
|
show_result() {
|
||||||
echo ""
|
echo ""
|
||||||
echo -e "${GREEN}╔═══════════════════════════════════════════════════════════════╗${NC}"
|
echo -e "${GREEN}╔═══════════════════════════════════════════════════════════════╗${NC}"
|
||||||
echo -e "${GREEN}║ 安装完成! ║${NC}"
|
echo -e "${GREEN}║ 安装完成! ║${NC}"
|
||||||
echo -e "${GREEN}╚═══════════════════════════════════════════════════════════════╝${NC}"
|
echo -e "${GREEN}╚═══════════════════════════════════════════════════════════════╝${NC}"
|
||||||
echo ""
|
echo ""
|
||||||
echo "访问地址:"
|
echo "访问地址:"
|
||||||
echo -e " Grafana: ${CYAN}https://$DOMAIN${NC}"
|
echo -e " Grafana: ${CYAN}https://$DOMAIN${NC}"
|
||||||
echo -e " 用户名: ${YELLOW}$GRAFANA_USER${NC}"
|
echo -e " 用户名: ${YELLOW}$GRAFANA_USER${NC}"
|
||||||
echo -e " 密码: ${YELLOW}$GRAFANA_PASS${NC}"
|
echo -e " 密码: ${YELLOW}$GRAFANA_PASS${NC}"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Prometheus (仅内网可访问):"
|
echo "Prometheus (仅内网可访问):"
|
||||||
echo -e " 地址: ${CYAN}https://$DOMAIN/prometheus/${NC}"
|
echo -e " 地址: ${CYAN}https://$DOMAIN/prometheus/${NC}"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Kong 指标端点:"
|
echo "Kong 指标端点:"
|
||||||
echo -e " 地址: ${CYAN}http://localhost:8001/metrics${NC}"
|
echo -e " 地址: ${CYAN}http://localhost:8001/metrics${NC}"
|
||||||
echo ""
|
echo ""
|
||||||
echo "管理命令:"
|
echo "管理命令:"
|
||||||
echo " ./deploy.sh monitoring up # 启动监控"
|
echo " ./deploy.sh monitoring up # 启动监控"
|
||||||
echo " ./deploy.sh monitoring down # 停止监控"
|
echo " ./deploy.sh monitoring down # 停止监控"
|
||||||
echo " ./deploy.sh metrics # 查看指标"
|
echo " ./deploy.sh metrics # 查看指标"
|
||||||
echo ""
|
echo ""
|
||||||
}
|
}
|
||||||
|
|
||||||
# 卸载函数
|
# 卸载函数
|
||||||
uninstall() {
|
uninstall() {
|
||||||
log_warn "正在卸载监控栈..."
|
log_warn "正在卸载监控栈..."
|
||||||
|
|
||||||
# 停止服务
|
# 停止服务
|
||||||
cd "$PROJECT_DIR"
|
cd "$PROJECT_DIR"
|
||||||
docker stop rwa-prometheus rwa-grafana 2>/dev/null || true
|
docker stop rwa-prometheus rwa-grafana 2>/dev/null || true
|
||||||
docker rm rwa-prometheus rwa-grafana 2>/dev/null || true
|
docker rm rwa-prometheus rwa-grafana 2>/dev/null || true
|
||||||
|
|
||||||
# 删除 Nginx 配置
|
# 删除 Nginx 配置
|
||||||
rm -f /etc/nginx/sites-enabled/$DOMAIN.conf
|
rm -f /etc/nginx/sites-enabled/$DOMAIN.conf
|
||||||
rm -f /etc/nginx/sites-available/$DOMAIN.conf
|
rm -f /etc/nginx/sites-available/$DOMAIN.conf
|
||||||
systemctl reload nginx 2>/dev/null || true
|
systemctl reload nginx 2>/dev/null || true
|
||||||
|
|
||||||
log_success "监控栈已卸载"
|
log_success "监控栈已卸载"
|
||||||
echo "注意: SSL 证书未删除,如需删除请运行: certbot delete --cert-name $DOMAIN"
|
echo "注意: SSL 证书未删除,如需删除请运行: certbot delete --cert-name $DOMAIN"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 主函数
|
# 主函数
|
||||||
main() {
|
main() {
|
||||||
show_banner
|
show_banner
|
||||||
|
|
||||||
# 检查是否卸载
|
# 检查是否卸载
|
||||||
if [ "$1" = "uninstall" ] || [ "$1" = "--uninstall" ]; then
|
if [ "$1" = "uninstall" ] || [ "$1" = "--uninstall" ]; then
|
||||||
uninstall
|
uninstall
|
||||||
exit 0
|
exit 0
|
||||||
fi
|
fi
|
||||||
|
|
||||||
check_root
|
check_root
|
||||||
check_dependencies
|
check_dependencies
|
||||||
check_dns
|
check_dns
|
||||||
generate_nginx_config
|
generate_nginx_config
|
||||||
obtain_ssl_cert
|
obtain_ssl_cert
|
||||||
enable_nginx_config
|
enable_nginx_config
|
||||||
start_monitoring_services
|
start_monitoring_services
|
||||||
show_result
|
show_result
|
||||||
}
|
}
|
||||||
|
|
||||||
main "$@"
|
main "$@"
|
||||||
|
|
|
||||||
|
|
@ -1,31 +1,31 @@
|
||||||
{
|
{
|
||||||
"permissions": {
|
"permissions": {
|
||||||
"allow": [
|
"allow": [
|
||||||
"Bash(dir:*)",
|
"Bash(dir:*)",
|
||||||
"Bash(go mod tidy:*)",
|
"Bash(go mod tidy:*)",
|
||||||
"Bash(cat:*)",
|
"Bash(cat:*)",
|
||||||
"Bash(go build:*)",
|
"Bash(go build:*)",
|
||||||
"Bash(go test:*)",
|
"Bash(go test:*)",
|
||||||
"Bash(go tool cover:*)",
|
"Bash(go tool cover:*)",
|
||||||
"Bash(wsl -e bash -c \"docker --version && docker-compose --version\")",
|
"Bash(wsl -e bash -c \"docker --version && docker-compose --version\")",
|
||||||
"Bash(wsl -e bash -c:*)",
|
"Bash(wsl -e bash -c:*)",
|
||||||
"Bash(timeout 180 bash -c 'while true; do status=$(wsl -e bash -c \"\"which docker 2>/dev/null\"\"); if [ -n \"\"$status\"\" ]; then echo \"\"Docker installed\"\"; break; fi; sleep 5; done')",
|
"Bash(timeout 180 bash -c 'while true; do status=$(wsl -e bash -c \"\"which docker 2>/dev/null\"\"); if [ -n \"\"$status\"\" ]; then echo \"\"Docker installed\"\"; break; fi; sleep 5; done')",
|
||||||
"Bash(docker --version:*)",
|
"Bash(docker --version:*)",
|
||||||
"Bash(powershell -c:*)",
|
"Bash(powershell -c:*)",
|
||||||
"Bash(go version:*)",
|
"Bash(go version:*)",
|
||||||
"Bash(set TEST_DATABASE_URL=postgres://mpc_user:mpc_password@localhost:5433/mpc_system_test?sslmode=disable:*)",
|
"Bash(set TEST_DATABASE_URL=postgres://mpc_user:mpc_password@localhost:5433/mpc_system_test?sslmode=disable:*)",
|
||||||
"Bash(Select-String -Pattern \"PASS|FAIL|RUN\")",
|
"Bash(Select-String -Pattern \"PASS|FAIL|RUN\")",
|
||||||
"Bash(Select-Object -Last 30)",
|
"Bash(Select-Object -Last 30)",
|
||||||
"Bash(Select-String -Pattern \"grpc_handler.go\")",
|
"Bash(Select-String -Pattern \"grpc_handler.go\")",
|
||||||
"Bash(Select-Object -First 10)",
|
"Bash(Select-Object -First 10)",
|
||||||
"Bash(git add:*)",
|
"Bash(git add:*)",
|
||||||
"Bash(git commit:*)",
|
"Bash(git commit:*)",
|
||||||
"Bash(where:*)",
|
"Bash(where:*)",
|
||||||
"Bash(go get:*)",
|
"Bash(go get:*)",
|
||||||
"Bash(findstr:*)",
|
"Bash(findstr:*)",
|
||||||
"Bash(git push)"
|
"Bash(git push)"
|
||||||
],
|
],
|
||||||
"deny": [],
|
"deny": [],
|
||||||
"ask": []
|
"ask": []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,93 +1,93 @@
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# MPC System - Environment Configuration
|
# MPC System - Environment Configuration
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# This file contains all environment variables needed for MPC System deployment.
|
# This file contains all environment variables needed for MPC System deployment.
|
||||||
#
|
#
|
||||||
# Setup Instructions:
|
# Setup Instructions:
|
||||||
# 1. Copy this file: cp .env.example .env
|
# 1. Copy this file: cp .env.example .env
|
||||||
# 2. Update ALL values according to your production environment
|
# 2. Update ALL values according to your production environment
|
||||||
# 3. Generate secure random keys for secrets (see instructions below)
|
# 3. Generate secure random keys for secrets (see instructions below)
|
||||||
# 4. Start services: ./deploy.sh up
|
# 4. Start services: ./deploy.sh up
|
||||||
#
|
#
|
||||||
# IMPORTANT: This file contains examples only!
|
# IMPORTANT: This file contains examples only!
|
||||||
# In production, you MUST:
|
# In production, you MUST:
|
||||||
# - Change ALL passwords and keys to secure random values
|
# - Change ALL passwords and keys to secure random values
|
||||||
# - Update ALLOWED_IPS to match your actual backend server IP
|
# - Update ALLOWED_IPS to match your actual backend server IP
|
||||||
# - Keep the .env file secure and NEVER commit it to version control
|
# - Keep the .env file secure and NEVER commit it to version control
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Environment Identifier
|
# Environment Identifier
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Options: development, staging, production
|
# Options: development, staging, production
|
||||||
ENVIRONMENT=production
|
ENVIRONMENT=production
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# PostgreSQL Database Configuration
|
# PostgreSQL Database Configuration
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Database user (can keep default or customize)
|
# Database user (can keep default or customize)
|
||||||
POSTGRES_USER=mpc_user
|
POSTGRES_USER=mpc_user
|
||||||
|
|
||||||
# Database password
|
# Database password
|
||||||
# SECURITY: Generate a strong password in production!
|
# SECURITY: Generate a strong password in production!
|
||||||
# Example command: openssl rand -base64 32
|
# Example command: openssl rand -base64 32
|
||||||
POSTGRES_PASSWORD=change_this_to_secure_postgres_password
|
POSTGRES_PASSWORD=change_this_to_secure_postgres_password
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Redis Cache Configuration
|
# Redis Cache Configuration
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Redis password (leave empty if Redis is only accessible within Docker network)
|
# Redis password (leave empty if Redis is only accessible within Docker network)
|
||||||
# For production, consider setting a password for defense in depth
|
# For production, consider setting a password for defense in depth
|
||||||
# Example command: openssl rand -base64 24
|
# Example command: openssl rand -base64 24
|
||||||
REDIS_PASSWORD=
|
REDIS_PASSWORD=
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# RabbitMQ Message Broker Configuration
|
# RabbitMQ Message Broker Configuration
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# RabbitMQ user (can keep default or customize)
|
# RabbitMQ user (can keep default or customize)
|
||||||
RABBITMQ_USER=mpc_user
|
RABBITMQ_USER=mpc_user
|
||||||
|
|
||||||
# RabbitMQ password
|
# RabbitMQ password
|
||||||
# SECURITY: Generate a strong password in production!
|
# SECURITY: Generate a strong password in production!
|
||||||
# Example command: openssl rand -base64 32
|
# Example command: openssl rand -base64 32
|
||||||
RABBITMQ_PASSWORD=change_this_to_secure_rabbitmq_password
|
RABBITMQ_PASSWORD=change_this_to_secure_rabbitmq_password
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# JWT Configuration
|
# JWT Configuration
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# JWT signing secret key (minimum 32 characters)
|
# JWT signing secret key (minimum 32 characters)
|
||||||
# SECURITY: Generate a strong random key in production!
|
# SECURITY: Generate a strong random key in production!
|
||||||
# Example command: openssl rand -base64 48
|
# Example command: openssl rand -base64 48
|
||||||
JWT_SECRET_KEY=change_this_jwt_secret_key_to_random_value_min_32_chars
|
JWT_SECRET_KEY=change_this_jwt_secret_key_to_random_value_min_32_chars
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Cryptography Configuration
|
# Cryptography Configuration
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Master encryption key for encrypting stored key shares
|
# Master encryption key for encrypting stored key shares
|
||||||
# MUST be exactly 64 hexadecimal characters (256-bit key)
|
# MUST be exactly 64 hexadecimal characters (256-bit key)
|
||||||
# SECURITY: Generate a secure random key in production!
|
# SECURITY: Generate a secure random key in production!
|
||||||
# Example command: openssl rand -hex 32
|
# Example command: openssl rand -hex 32
|
||||||
# WARNING: If you lose this key, encrypted shares cannot be recovered!
|
# WARNING: If you lose this key, encrypted shares cannot be recovered!
|
||||||
CRYPTO_MASTER_KEY=0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef
|
CRYPTO_MASTER_KEY=0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# API Security Configuration
|
# API Security Configuration
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# API authentication key for server-to-server communication
|
# API authentication key for server-to-server communication
|
||||||
# This key must match the MPC_API_KEY in your backend mpc-service configuration
|
# This key must match the MPC_API_KEY in your backend mpc-service configuration
|
||||||
# SECURITY: Generate a strong random key and keep it synchronized!
|
# SECURITY: Generate a strong random key and keep it synchronized!
|
||||||
# Example command: openssl rand -base64 48
|
# Example command: openssl rand -base64 48
|
||||||
MPC_API_KEY=change_this_api_key_to_match_your_mpc_service_config
|
MPC_API_KEY=change_this_api_key_to_match_your_mpc_service_config
|
||||||
|
|
||||||
# Allowed IP addresses (comma-separated list)
|
# Allowed IP addresses (comma-separated list)
|
||||||
# Only these IPs can access the MPC system APIs
|
# Only these IPs can access the MPC system APIs
|
||||||
# IMPORTANT: In production, restrict this to your actual backend server IP(s)!
|
# IMPORTANT: In production, restrict this to your actual backend server IP(s)!
|
||||||
# Examples:
|
# Examples:
|
||||||
# Single IP: ALLOWED_IPS=192.168.1.111
|
# Single IP: ALLOWED_IPS=192.168.1.111
|
||||||
# Multiple IPs: ALLOWED_IPS=192.168.1.111,192.168.1.112
|
# Multiple IPs: ALLOWED_IPS=192.168.1.111,192.168.1.112
|
||||||
# Local only: ALLOWED_IPS=127.0.0.1
|
# Local only: ALLOWED_IPS=127.0.0.1
|
||||||
# Allow all: ALLOWED_IPS= (empty, relies on API_KEY auth only - NOT RECOMMENDED for production)
|
# Allow all: ALLOWED_IPS= (empty, relies on API_KEY auth only - NOT RECOMMENDED for production)
|
||||||
#
|
#
|
||||||
# Default allows all IPs (protected by API_KEY authentication)
|
# Default allows all IPs (protected by API_KEY authentication)
|
||||||
# SECURITY WARNING: Change this in production to specific backend server IP(s)!
|
# SECURITY WARNING: Change this in production to specific backend server IP(s)!
|
||||||
ALLOWED_IPS=
|
ALLOWED_IPS=
|
||||||
|
|
|
||||||
|
|
@ -1,35 +1,35 @@
|
||||||
# Environment files (contain secrets)
|
# Environment files (contain secrets)
|
||||||
.env
|
.env
|
||||||
.env.local
|
.env.local
|
||||||
.env.production
|
.env.production
|
||||||
|
|
||||||
# Build artifacts
|
# Build artifacts
|
||||||
/bin/
|
/bin/
|
||||||
*.exe
|
*.exe
|
||||||
*.dll
|
*.dll
|
||||||
*.so
|
*.so
|
||||||
*.dylib
|
*.dylib
|
||||||
|
|
||||||
# Test binary
|
# Test binary
|
||||||
*.test
|
*.test
|
||||||
|
|
||||||
# Output of go coverage
|
# Output of go coverage
|
||||||
*.out
|
*.out
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
.idea/
|
.idea/
|
||||||
.vscode/
|
.vscode/
|
||||||
*.swp
|
*.swp
|
||||||
*.swo
|
*.swo
|
||||||
|
|
||||||
# OS files
|
# OS files
|
||||||
.DS_Store
|
.DS_Store
|
||||||
Thumbs.db
|
Thumbs.db
|
||||||
|
|
||||||
# Logs
|
# Logs
|
||||||
*.log
|
*.log
|
||||||
logs/
|
logs/
|
||||||
|
|
||||||
# Temporary files
|
# Temporary files
|
||||||
tmp/
|
tmp/
|
||||||
temp/
|
temp/
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,260 +1,260 @@
|
||||||
.PHONY: help proto build test docker-build docker-up docker-down deploy-k8s clean lint fmt
|
.PHONY: help proto build test docker-build docker-up docker-down deploy-k8s clean lint fmt
|
||||||
|
|
||||||
# Default target
|
# Default target
|
||||||
.DEFAULT_GOAL := help
|
.DEFAULT_GOAL := help
|
||||||
|
|
||||||
# Variables
|
# Variables
|
||||||
GO := go
|
GO := go
|
||||||
DOCKER := docker
|
DOCKER := docker
|
||||||
DOCKER_COMPOSE := docker-compose
|
DOCKER_COMPOSE := docker-compose
|
||||||
PROTOC := protoc
|
PROTOC := protoc
|
||||||
GOPATH := $(shell go env GOPATH)
|
GOPATH := $(shell go env GOPATH)
|
||||||
PROJECT_NAME := mpc-system
|
PROJECT_NAME := mpc-system
|
||||||
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||||
BUILD_TIME := $(shell date -u '+%Y-%m-%d_%H:%M:%S')
|
BUILD_TIME := $(shell date -u '+%Y-%m-%d_%H:%M:%S')
|
||||||
LDFLAGS := -ldflags "-X main.Version=$(VERSION) -X main.BuildTime=$(BUILD_TIME)"
|
LDFLAGS := -ldflags "-X main.Version=$(VERSION) -X main.BuildTime=$(BUILD_TIME)"
|
||||||
|
|
||||||
# Services
|
# Services
|
||||||
SERVICES := session-coordinator message-router server-party account
|
SERVICES := session-coordinator message-router server-party account
|
||||||
|
|
||||||
help: ## Show this help
|
help: ## Show this help
|
||||||
@echo "MPC Distributed Signature System - Build Commands"
|
@echo "MPC Distributed Signature System - Build Commands"
|
||||||
@echo ""
|
@echo ""
|
||||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
|
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Development Commands
|
# Development Commands
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
init: ## Initialize the project (install tools)
|
init: ## Initialize the project (install tools)
|
||||||
@echo "Installing tools..."
|
@echo "Installing tools..."
|
||||||
$(GO) install google.golang.org/protobuf/cmd/protoc-gen-go@latest
|
$(GO) install google.golang.org/protobuf/cmd/protoc-gen-go@latest
|
||||||
$(GO) install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
|
$(GO) install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
|
||||||
$(GO) install github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-grpc-gateway@latest
|
$(GO) install github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-grpc-gateway@latest
|
||||||
$(GO) install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
$(GO) install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||||
$(GO) mod download
|
$(GO) mod download
|
||||||
@echo "Tools installed successfully!"
|
@echo "Tools installed successfully!"
|
||||||
|
|
||||||
proto: ## Generate protobuf code
|
proto: ## Generate protobuf code
|
||||||
@echo "Generating protobuf..."
|
@echo "Generating protobuf..."
|
||||||
@mkdir -p api/grpc/coordinator/v1
|
@mkdir -p api/grpc/coordinator/v1
|
||||||
@mkdir -p api/grpc/router/v1
|
@mkdir -p api/grpc/router/v1
|
||||||
$(PROTOC) --go_out=. --go_opt=paths=source_relative \
|
$(PROTOC) --go_out=. --go_opt=paths=source_relative \
|
||||||
--go-grpc_out=. --go-grpc_opt=paths=source_relative \
|
--go-grpc_out=. --go-grpc_opt=paths=source_relative \
|
||||||
api/proto/*.proto
|
api/proto/*.proto
|
||||||
@echo "Protobuf generated successfully!"
|
@echo "Protobuf generated successfully!"
|
||||||
|
|
||||||
fmt: ## Format Go code
|
fmt: ## Format Go code
|
||||||
@echo "Formatting code..."
|
@echo "Formatting code..."
|
||||||
$(GO) fmt ./...
|
$(GO) fmt ./...
|
||||||
@echo "Code formatted!"
|
@echo "Code formatted!"
|
||||||
|
|
||||||
lint: ## Run linter
|
lint: ## Run linter
|
||||||
@echo "Running linter..."
|
@echo "Running linter..."
|
||||||
golangci-lint run ./...
|
golangci-lint run ./...
|
||||||
@echo "Lint completed!"
|
@echo "Lint completed!"
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Build Commands
|
# Build Commands
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
build: ## Build all services
|
build: ## Build all services
|
||||||
@echo "Building all services..."
|
@echo "Building all services..."
|
||||||
@for service in $(SERVICES); do \
|
@for service in $(SERVICES); do \
|
||||||
echo "Building $$service..."; \
|
echo "Building $$service..."; \
|
||||||
$(GO) build $(LDFLAGS) -o bin/$$service ./services/$$service/cmd/server; \
|
$(GO) build $(LDFLAGS) -o bin/$$service ./services/$$service/cmd/server; \
|
||||||
done
|
done
|
||||||
@echo "All services built successfully!"
|
@echo "All services built successfully!"
|
||||||
|
|
||||||
build-session-coordinator: ## Build session-coordinator service
|
build-session-coordinator: ## Build session-coordinator service
|
||||||
@echo "Building session-coordinator..."
|
@echo "Building session-coordinator..."
|
||||||
$(GO) build $(LDFLAGS) -o bin/session-coordinator ./services/session-coordinator/cmd/server
|
$(GO) build $(LDFLAGS) -o bin/session-coordinator ./services/session-coordinator/cmd/server
|
||||||
|
|
||||||
build-message-router: ## Build message-router service
|
build-message-router: ## Build message-router service
|
||||||
@echo "Building message-router..."
|
@echo "Building message-router..."
|
||||||
$(GO) build $(LDFLAGS) -o bin/message-router ./services/message-router/cmd/server
|
$(GO) build $(LDFLAGS) -o bin/message-router ./services/message-router/cmd/server
|
||||||
|
|
||||||
build-server-party: ## Build server-party service
|
build-server-party: ## Build server-party service
|
||||||
@echo "Building server-party..."
|
@echo "Building server-party..."
|
||||||
$(GO) build $(LDFLAGS) -o bin/server-party ./services/server-party/cmd/server
|
$(GO) build $(LDFLAGS) -o bin/server-party ./services/server-party/cmd/server
|
||||||
|
|
||||||
build-account: ## Build account service
|
build-account: ## Build account service
|
||||||
@echo "Building account service..."
|
@echo "Building account service..."
|
||||||
$(GO) build $(LDFLAGS) -o bin/account ./services/account/cmd/server
|
$(GO) build $(LDFLAGS) -o bin/account ./services/account/cmd/server
|
||||||
|
|
||||||
clean: ## Clean build artifacts
|
clean: ## Clean build artifacts
|
||||||
@echo "Cleaning..."
|
@echo "Cleaning..."
|
||||||
rm -rf bin/
|
rm -rf bin/
|
||||||
rm -rf vendor/
|
rm -rf vendor/
|
||||||
$(GO) clean -cache
|
$(GO) clean -cache
|
||||||
@echo "Cleaned!"
|
@echo "Cleaned!"
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Test Commands
|
# Test Commands
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
test: ## Run all tests
|
test: ## Run all tests
|
||||||
@echo "Running tests..."
|
@echo "Running tests..."
|
||||||
$(GO) test -v -race -coverprofile=coverage.out ./...
|
$(GO) test -v -race -coverprofile=coverage.out ./...
|
||||||
@echo "Tests completed!"
|
@echo "Tests completed!"
|
||||||
|
|
||||||
test-unit: ## Run unit tests only
|
test-unit: ## Run unit tests only
|
||||||
@echo "Running unit tests..."
|
@echo "Running unit tests..."
|
||||||
$(GO) test -v -race -short ./...
|
$(GO) test -v -race -short ./...
|
||||||
@echo "Unit tests completed!"
|
@echo "Unit tests completed!"
|
||||||
|
|
||||||
test-integration: ## Run integration tests
|
test-integration: ## Run integration tests
|
||||||
@echo "Running integration tests..."
|
@echo "Running integration tests..."
|
||||||
$(GO) test -v -race -tags=integration ./tests/integration/...
|
$(GO) test -v -race -tags=integration ./tests/integration/...
|
||||||
@echo "Integration tests completed!"
|
@echo "Integration tests completed!"
|
||||||
|
|
||||||
test-e2e: ## Run end-to-end tests
|
test-e2e: ## Run end-to-end tests
|
||||||
@echo "Running e2e tests..."
|
@echo "Running e2e tests..."
|
||||||
$(GO) test -v -race -tags=e2e ./tests/e2e/...
|
$(GO) test -v -race -tags=e2e ./tests/e2e/...
|
||||||
@echo "E2E tests completed!"
|
@echo "E2E tests completed!"
|
||||||
|
|
||||||
test-coverage: ## Run tests with coverage report
|
test-coverage: ## Run tests with coverage report
|
||||||
@echo "Running tests with coverage..."
|
@echo "Running tests with coverage..."
|
||||||
$(GO) test -v -race -coverprofile=coverage.out -covermode=atomic ./...
|
$(GO) test -v -race -coverprofile=coverage.out -covermode=atomic ./...
|
||||||
$(GO) tool cover -html=coverage.out -o coverage.html
|
$(GO) tool cover -html=coverage.out -o coverage.html
|
||||||
@echo "Coverage report generated: coverage.html"
|
@echo "Coverage report generated: coverage.html"
|
||||||
|
|
||||||
test-docker-integration: ## Run integration tests in Docker
|
test-docker-integration: ## Run integration tests in Docker
|
||||||
@echo "Starting test infrastructure..."
|
@echo "Starting test infrastructure..."
|
||||||
$(DOCKER_COMPOSE) -f tests/docker-compose.test.yml up -d postgres-test redis-test rabbitmq-test
|
$(DOCKER_COMPOSE) -f tests/docker-compose.test.yml up -d postgres-test redis-test rabbitmq-test
|
||||||
@echo "Waiting for services..."
|
@echo "Waiting for services..."
|
||||||
sleep 10
|
sleep 10
|
||||||
$(DOCKER_COMPOSE) -f tests/docker-compose.test.yml run --rm migrate
|
$(DOCKER_COMPOSE) -f tests/docker-compose.test.yml run --rm migrate
|
||||||
$(DOCKER_COMPOSE) -f tests/docker-compose.test.yml run --rm integration-tests
|
$(DOCKER_COMPOSE) -f tests/docker-compose.test.yml run --rm integration-tests
|
||||||
@echo "Integration tests completed!"
|
@echo "Integration tests completed!"
|
||||||
|
|
||||||
test-docker-e2e: ## Run E2E tests in Docker
|
test-docker-e2e: ## Run E2E tests in Docker
|
||||||
@echo "Starting full test environment..."
|
@echo "Starting full test environment..."
|
||||||
$(DOCKER_COMPOSE) -f tests/docker-compose.test.yml up -d
|
$(DOCKER_COMPOSE) -f tests/docker-compose.test.yml up -d
|
||||||
@echo "Waiting for services to be healthy..."
|
@echo "Waiting for services to be healthy..."
|
||||||
sleep 30
|
sleep 30
|
||||||
$(DOCKER_COMPOSE) -f tests/docker-compose.test.yml run --rm e2e-tests
|
$(DOCKER_COMPOSE) -f tests/docker-compose.test.yml run --rm e2e-tests
|
||||||
@echo "E2E tests completed!"
|
@echo "E2E tests completed!"
|
||||||
|
|
||||||
test-docker-all: ## Run all tests in Docker
|
test-docker-all: ## Run all tests in Docker
|
||||||
@echo "Running all tests in Docker..."
|
@echo "Running all tests in Docker..."
|
||||||
$(MAKE) test-docker-integration
|
$(MAKE) test-docker-integration
|
||||||
$(MAKE) test-docker-e2e
|
$(MAKE) test-docker-e2e
|
||||||
@echo "All Docker tests completed!"
|
@echo "All Docker tests completed!"
|
||||||
|
|
||||||
test-clean: ## Clean up test resources
|
test-clean: ## Clean up test resources
|
||||||
@echo "Cleaning up test resources..."
|
@echo "Cleaning up test resources..."
|
||||||
$(DOCKER_COMPOSE) -f tests/docker-compose.test.yml down -v --remove-orphans
|
$(DOCKER_COMPOSE) -f tests/docker-compose.test.yml down -v --remove-orphans
|
||||||
rm -f coverage.out coverage.html
|
rm -f coverage.out coverage.html
|
||||||
@echo "Test cleanup completed!"
|
@echo "Test cleanup completed!"
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Docker Commands
|
# Docker Commands
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
docker-build: ## Build Docker images
|
docker-build: ## Build Docker images
|
||||||
@echo "Building Docker images..."
|
@echo "Building Docker images..."
|
||||||
$(DOCKER_COMPOSE) build
|
$(DOCKER_COMPOSE) build
|
||||||
@echo "Docker images built!"
|
@echo "Docker images built!"
|
||||||
|
|
||||||
docker-up: ## Start all services with Docker Compose
|
docker-up: ## Start all services with Docker Compose
|
||||||
@echo "Starting services..."
|
@echo "Starting services..."
|
||||||
$(DOCKER_COMPOSE) up -d
|
$(DOCKER_COMPOSE) up -d
|
||||||
@echo "Services started!"
|
@echo "Services started!"
|
||||||
|
|
||||||
docker-down: ## Stop all services
|
docker-down: ## Stop all services
|
||||||
@echo "Stopping services..."
|
@echo "Stopping services..."
|
||||||
$(DOCKER_COMPOSE) down
|
$(DOCKER_COMPOSE) down
|
||||||
@echo "Services stopped!"
|
@echo "Services stopped!"
|
||||||
|
|
||||||
docker-logs: ## View logs
|
docker-logs: ## View logs
|
||||||
$(DOCKER_COMPOSE) logs -f
|
$(DOCKER_COMPOSE) logs -f
|
||||||
|
|
||||||
docker-ps: ## View running containers
|
docker-ps: ## View running containers
|
||||||
$(DOCKER_COMPOSE) ps
|
$(DOCKER_COMPOSE) ps
|
||||||
|
|
||||||
docker-clean: ## Remove all containers and volumes
|
docker-clean: ## Remove all containers and volumes
|
||||||
@echo "Cleaning Docker resources..."
|
@echo "Cleaning Docker resources..."
|
||||||
$(DOCKER_COMPOSE) down -v --remove-orphans
|
$(DOCKER_COMPOSE) down -v --remove-orphans
|
||||||
@echo "Docker resources cleaned!"
|
@echo "Docker resources cleaned!"
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Database Commands
|
# Database Commands
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
db-migrate: ## Run database migrations
|
db-migrate: ## Run database migrations
|
||||||
@echo "Running database migrations..."
|
@echo "Running database migrations..."
|
||||||
psql -h localhost -U mpc_user -d mpc_system -f migrations/001_init_schema.sql
|
psql -h localhost -U mpc_user -d mpc_system -f migrations/001_init_schema.sql
|
||||||
@echo "Migrations completed!"
|
@echo "Migrations completed!"
|
||||||
|
|
||||||
db-reset: ## Reset database (drop and recreate)
|
db-reset: ## Reset database (drop and recreate)
|
||||||
@echo "Resetting database..."
|
@echo "Resetting database..."
|
||||||
psql -h localhost -U mpc_user -d postgres -c "DROP DATABASE IF EXISTS mpc_system"
|
psql -h localhost -U mpc_user -d postgres -c "DROP DATABASE IF EXISTS mpc_system"
|
||||||
psql -h localhost -U mpc_user -d postgres -c "CREATE DATABASE mpc_system"
|
psql -h localhost -U mpc_user -d postgres -c "CREATE DATABASE mpc_system"
|
||||||
$(MAKE) db-migrate
|
$(MAKE) db-migrate
|
||||||
@echo "Database reset completed!"
|
@echo "Database reset completed!"
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Mobile SDK Commands
|
# Mobile SDK Commands
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
build-android-sdk: ## Build Android SDK
|
build-android-sdk: ## Build Android SDK
|
||||||
@echo "Building Android SDK..."
|
@echo "Building Android SDK..."
|
||||||
gomobile bind -target=android -o sdk/android/mpcsdk.aar ./sdk/go
|
gomobile bind -target=android -o sdk/android/mpcsdk.aar ./sdk/go
|
||||||
@echo "Android SDK built!"
|
@echo "Android SDK built!"
|
||||||
|
|
||||||
build-ios-sdk: ## Build iOS SDK
|
build-ios-sdk: ## Build iOS SDK
|
||||||
@echo "Building iOS SDK..."
|
@echo "Building iOS SDK..."
|
||||||
gomobile bind -target=ios -o sdk/ios/Mpcsdk.xcframework ./sdk/go
|
gomobile bind -target=ios -o sdk/ios/Mpcsdk.xcframework ./sdk/go
|
||||||
@echo "iOS SDK built!"
|
@echo "iOS SDK built!"
|
||||||
|
|
||||||
build-mobile-sdk: build-android-sdk build-ios-sdk ## Build all mobile SDKs
|
build-mobile-sdk: build-android-sdk build-ios-sdk ## Build all mobile SDKs
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Kubernetes Commands
|
# Kubernetes Commands
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
deploy-k8s: ## Deploy to Kubernetes
|
deploy-k8s: ## Deploy to Kubernetes
|
||||||
@echo "Deploying to Kubernetes..."
|
@echo "Deploying to Kubernetes..."
|
||||||
kubectl apply -f k8s/
|
kubectl apply -f k8s/
|
||||||
@echo "Deployed!"
|
@echo "Deployed!"
|
||||||
|
|
||||||
undeploy-k8s: ## Remove from Kubernetes
|
undeploy-k8s: ## Remove from Kubernetes
|
||||||
@echo "Removing from Kubernetes..."
|
@echo "Removing from Kubernetes..."
|
||||||
kubectl delete -f k8s/
|
kubectl delete -f k8s/
|
||||||
@echo "Removed!"
|
@echo "Removed!"
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Development Helpers
|
# Development Helpers
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
run-coordinator: ## Run session-coordinator locally
|
run-coordinator: ## Run session-coordinator locally
|
||||||
$(GO) run ./services/session-coordinator/cmd/server
|
$(GO) run ./services/session-coordinator/cmd/server
|
||||||
|
|
||||||
run-router: ## Run message-router locally
|
run-router: ## Run message-router locally
|
||||||
$(GO) run ./services/message-router/cmd/server
|
$(GO) run ./services/message-router/cmd/server
|
||||||
|
|
||||||
run-party: ## Run server-party locally
|
run-party: ## Run server-party locally
|
||||||
$(GO) run ./services/server-party/cmd/server
|
$(GO) run ./services/server-party/cmd/server
|
||||||
|
|
||||||
run-account: ## Run account service locally
|
run-account: ## Run account service locally
|
||||||
$(GO) run ./services/account/cmd/server
|
$(GO) run ./services/account/cmd/server
|
||||||
|
|
||||||
dev: docker-up ## Start development environment
|
dev: docker-up ## Start development environment
|
||||||
@echo "Development environment is ready!"
|
@echo "Development environment is ready!"
|
||||||
@echo " PostgreSQL: localhost:5432"
|
@echo " PostgreSQL: localhost:5432"
|
||||||
@echo " Redis: localhost:6379"
|
@echo " Redis: localhost:6379"
|
||||||
@echo " RabbitMQ: localhost:5672 (management: localhost:15672)"
|
@echo " RabbitMQ: localhost:5672 (management: localhost:15672)"
|
||||||
@echo " Consul: localhost:8500"
|
@echo " Consul: localhost:8500"
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Release Commands
|
# Release Commands
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
release: lint test build ## Create a release
|
release: lint test build ## Create a release
|
||||||
@echo "Creating release $(VERSION)..."
|
@echo "Creating release $(VERSION)..."
|
||||||
@echo "Release created!"
|
@echo "Release created!"
|
||||||
|
|
||||||
version: ## Show version
|
version: ## Show version
|
||||||
@echo "Version: $(VERSION)"
|
@echo "Version: $(VERSION)"
|
||||||
@echo "Build Time: $(BUILD_TIME)"
|
@echo "Build Time: $(BUILD_TIME)"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,295 @@
|
||||||
|
# Party Role Labels Implementation - Verification Report
|
||||||
|
|
||||||
|
**Date**: 2025-12-05
|
||||||
|
**Commit**: e975e9d - "feat(mpc-system): implement party role labels with strict persistent-only default"
|
||||||
|
**Environment**: Docker Compose (Local Development)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Implementation Summary
|
||||||
|
|
||||||
|
### 1.1 Overview
|
||||||
|
Implemented Party Role Labels (Solution 1) to differentiate between three types of server parties:
|
||||||
|
- **Persistent**: Stores key shares in database permanently
|
||||||
|
- **Delegate**: Generates user shares and returns them to caller (doesn't store)
|
||||||
|
- **Temporary**: For ad-hoc operations
|
||||||
|
|
||||||
|
### 1.2 Core Changes
|
||||||
|
|
||||||
|
#### Files Modified
|
||||||
|
1. `services/session-coordinator/application/ports/output/party_pool_port.go`
|
||||||
|
- Added `PartyRole` enum (persistent, delegate, temporary)
|
||||||
|
- Added `PartyEndpoint.Role` field
|
||||||
|
- Added `PartySelectionFilter` struct with role filtering
|
||||||
|
- Added `SelectPartiesWithFilter()` and `GetAvailablePartiesByRole()` methods
|
||||||
|
|
||||||
|
2. `services/session-coordinator/infrastructure/k8s/party_discovery.go`
|
||||||
|
- Implemented role extraction from K8s pod labels (`party-role`)
|
||||||
|
- Implemented `GetAvailablePartiesByRole()` for role-based filtering
|
||||||
|
- Implemented `SelectPartiesWithFilter()` with role and count requirements
|
||||||
|
- Default role: `persistent` if label not found
|
||||||
|
|
||||||
|
3. `services/session-coordinator/application/ports/input/session_management_port.go`
|
||||||
|
- Added `PartyComposition` struct with role-based party counts
|
||||||
|
- Added optional `PartyComposition` field to `CreateSessionInput`
|
||||||
|
|
||||||
|
4. `services/session-coordinator/application/use_cases/create_session.go`
|
||||||
|
- Implemented strict persistent-only default policy (lines 102-114)
|
||||||
|
- Implemented `selectPartiesByComposition()` method with empty composition validation (lines 224-284)
|
||||||
|
- Added clear error messages for insufficient parties
|
||||||
|
|
||||||
|
5. `k8s/server-party-deployment.yaml`
|
||||||
|
- Added label: `party-role: persistent` (line 25)
|
||||||
|
|
||||||
|
6. `k8s/server-party-api-deployment.yaml` (NEW FILE)
|
||||||
|
- New deployment for delegate parties
|
||||||
|
- Added label: `party-role: delegate` (line 25)
|
||||||
|
- Replicas: 2 (for generating user shares)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Security Policy Implementation
|
||||||
|
|
||||||
|
### 2.1 Strict Persistent-Only Default
|
||||||
|
When `PartyComposition` is **nil** (not specified):
|
||||||
|
- System MUST select only `persistent` parties
|
||||||
|
- If insufficient persistent parties available → **Fail immediately with clear error**
|
||||||
|
- NO automatic fallback to delegate/temporary parties
|
||||||
|
- Error message: "insufficient persistent parties: need N persistent parties but not enough available. Use PartyComposition to specify custom party requirements"
|
||||||
|
|
||||||
|
**Code Reference**: [create_session.go:102-114](c:\Users\dong\Desktop\rwadurian\backend\mpc-system\services\session-coordinator\application\use_cases\create_session.go#L102-L114)
|
||||||
|
|
||||||
|
### 2.2 Empty PartyComposition Validation
|
||||||
|
When `PartyComposition` is specified but all counts are 0:
|
||||||
|
- System returns error: "PartyComposition specified but no parties selected: all counts are zero and no custom filters provided"
|
||||||
|
- Prevents accidental bypass of persistent-only requirement
|
||||||
|
|
||||||
|
**Code Reference**: [create_session.go:279-281](c:\Users\dong\Desktop\rwadurian\backend\mpc-system\services\session-coordinator\application\use_cases\create_session.go#L279-L281)
|
||||||
|
|
||||||
|
### 2.3 Threshold Security Guarantee
|
||||||
|
- Default policy ensures MPC threshold security by using only persistent parties
|
||||||
|
- Persistent parties store shares in database, ensuring T-of-N shares are always available for future sign operations
|
||||||
|
- Delegate parties (which don't store shares) are only used when explicitly specified via `PartyComposition`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Docker Compose Deployment Verification
|
||||||
|
|
||||||
|
### 3.1 Build Status
|
||||||
|
**Command**: `./deploy.sh build`
|
||||||
|
**Status**: ✅ SUCCESS
|
||||||
|
**Images Built**:
|
||||||
|
1. mpc-system-postgres (postgres:15-alpine)
|
||||||
|
2. mpc-system-rabbitmq (rabbitmq:3-management-alpine)
|
||||||
|
3. mpc-system-redis (redis:7-alpine)
|
||||||
|
4. mpc-system-session-coordinator
|
||||||
|
5. mpc-system-message-router
|
||||||
|
6. mpc-system-server-party-1/2/3
|
||||||
|
7. mpc-system-server-party-api
|
||||||
|
8. mpc-system-account-service
|
||||||
|
|
||||||
|
### 3.2 Deployment Status
|
||||||
|
**Command**: `./deploy.sh up`
|
||||||
|
**Status**: ✅ SUCCESS
|
||||||
|
**Services Running** (10 containers):
|
||||||
|
|
||||||
|
| Service | Status | Ports | Notes |
|
||||||
|
|---------|--------|-------|-------|
|
||||||
|
| mpc-postgres | Healthy | 5432 (internal) | PostgreSQL 15 |
|
||||||
|
| mpc-rabbitmq | Healthy | 5672, 15672 (internal) | Message broker |
|
||||||
|
| mpc-redis | Healthy | 6379 (internal) | Cache store |
|
||||||
|
| mpc-session-coordinator | Healthy | 8081:8080 | Core orchestration |
|
||||||
|
| mpc-message-router | Healthy | 8082:8080 | Message routing |
|
||||||
|
| mpc-server-party-1 | Healthy | 50051, 8080 (internal) | Persistent party |
|
||||||
|
| mpc-server-party-2 | Healthy | 50051, 8080 (internal) | Persistent party |
|
||||||
|
| mpc-server-party-3 | Healthy | 50051, 8080 (internal) | Persistent party |
|
||||||
|
| mpc-server-party-api | Healthy | 8083:8080 | Delegate party |
|
||||||
|
| mpc-account-service | Healthy | 4000:8080 | Application service |
|
||||||
|
|
||||||
|
### 3.3 Health Check Results
|
||||||
|
```bash
|
||||||
|
# Session Coordinator
|
||||||
|
$ curl http://localhost:8081/health
|
||||||
|
{"service":"session-coordinator","status":"healthy"}
|
||||||
|
|
||||||
|
# Account Service
|
||||||
|
$ curl http://localhost:4000/health
|
||||||
|
{"service":"account","status":"healthy"}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Status**: ✅ All services responding to health checks
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Known Limitations in Docker Compose Environment
|
||||||
|
|
||||||
|
### 4.1 K8s Party Discovery Not Available
|
||||||
|
**Log Message**:
|
||||||
|
```
|
||||||
|
{"level":"warn","message":"K8s party discovery not available, will use dynamic join mode",
|
||||||
|
"error":"failed to create k8s config: stat /home/mpc/.kube/config: no such file or directory"}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Impact**:
|
||||||
|
- Party role labels (`party-role`) from K8s deployments are not accessible in Docker Compose
|
||||||
|
- System falls back to dynamic join mode (universal join tokens)
|
||||||
|
- `PartyPoolPort` is not available, so `selectPartiesByComposition()` logic is not exercised
|
||||||
|
|
||||||
|
**Why This Happens**:
|
||||||
|
- Docker Compose doesn't provide K8s API access
|
||||||
|
- Party discovery requires K8s Service Discovery and pod label queries
|
||||||
|
- This is expected behavior for non-K8s environments
|
||||||
|
|
||||||
|
### 4.2 Party Role Labels Not Testable in Docker Compose
|
||||||
|
The following features cannot be tested in Docker Compose:
|
||||||
|
1. Role-based party filtering (`SelectPartiesWithFilter`)
|
||||||
|
2. `PartyComposition`-based party selection
|
||||||
|
3. Strict persistent-only default policy
|
||||||
|
4. K8s pod label reading (`party-role`)
|
||||||
|
|
||||||
|
**These features require actual Kubernetes deployment to test.**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. What Was Verified
|
||||||
|
|
||||||
|
### 5.1 Code Compilation ✅
|
||||||
|
- All modified Go files compile successfully
|
||||||
|
- No syntax errors or type errors
|
||||||
|
- Build completes on both Windows (local) and WSL (Docker)
|
||||||
|
|
||||||
|
### 5.2 Service Deployment ✅
|
||||||
|
- All 10 services start successfully
|
||||||
|
- All health checks pass
|
||||||
|
- Services can connect to each other (gRPC connectivity verified in logs)
|
||||||
|
- Database connections established
|
||||||
|
- Message broker connections established
|
||||||
|
|
||||||
|
### 5.3 Code Logic Review ✅
|
||||||
|
- Strict persistent-only default policy correctly implemented
|
||||||
|
- Empty `PartyComposition` validation prevents loophole
|
||||||
|
- Clear error messages for insufficient parties
|
||||||
|
- Role extraction from K8s pod labels correctly implemented
|
||||||
|
- Role-based filtering logic correct
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. What Cannot Be Verified Without K8s
|
||||||
|
|
||||||
|
### 6.1 Runtime Behavior
|
||||||
|
1. **Party Discovery**: K8s pod label queries
|
||||||
|
2. **Role Filtering**: Actual filtering by `party-role` label values
|
||||||
|
3. **Persistent-Only Policy**: Enforcement when persistent parties insufficient
|
||||||
|
4. **Error Messages**: Actual error messages when party selection fails
|
||||||
|
5. **PartyComposition**: Custom party mix selection
|
||||||
|
|
||||||
|
### 6.2 Integration Testing
|
||||||
|
1. Creating a session with default (nil) `PartyComposition` → should select only persistent parties
|
||||||
|
2. Creating a session with insufficient persistent parties → should return clear error
|
||||||
|
3. Creating a session with empty `PartyComposition` → should return validation error
|
||||||
|
4. Creating a session with custom `PartyComposition` → should select correct party mix
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Next Steps for Full Verification
|
||||||
|
|
||||||
|
### 7.1 Deploy to Kubernetes Cluster
|
||||||
|
To fully test Party Role Labels, deploy to actual K8s cluster:
|
||||||
|
```bash
|
||||||
|
# Apply K8s manifests
|
||||||
|
kubectl apply -f k8s/namespace.yaml
|
||||||
|
kubectl apply -f k8s/configmap.yaml
|
||||||
|
kubectl apply -f k8s/secrets.yaml
|
||||||
|
kubectl apply -f k8s/postgres-deployment.yaml
|
||||||
|
kubectl apply -f k8s/rabbitmq-deployment.yaml
|
||||||
|
kubectl apply -f k8s/redis-deployment.yaml
|
||||||
|
kubectl apply -f k8s/server-party-deployment.yaml
|
||||||
|
kubectl apply -f k8s/server-party-api-deployment.yaml
|
||||||
|
kubectl apply -f k8s/session-coordinator-deployment.yaml
|
||||||
|
kubectl apply -f k8s/message-router-deployment.yaml
|
||||||
|
kubectl apply -f k8s/account-service-deployment.yaml
|
||||||
|
|
||||||
|
# Verify party discovery works
|
||||||
|
kubectl logs -n mpc-system -l app=mpc-session-coordinator | grep -i "party\|role\|discovery"
|
||||||
|
|
||||||
|
# Verify pod labels are set
|
||||||
|
kubectl get pods -n mpc-system -l app=mpc-server-party -o jsonpath='{range .items[*]}{.metadata.name}{"\t"}{.metadata.labels.party-role}{"\n"}{end}'
|
||||||
|
kubectl get pods -n mpc-system -l app=mpc-server-party-api -o jsonpath='{range .items[*]}{.metadata.name}{"\t"}{.metadata.labels.party-role}{"\n"}{end}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7.2 Integration Testing in K8s
|
||||||
|
1. **Test Default Persistent-Only Selection**:
|
||||||
|
```bash
|
||||||
|
curl -X POST http://<account-service>/api/v1/accounts \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"user_id": "test-user-1"}'
|
||||||
|
|
||||||
|
# Expected: Session created with 3 persistent parties
|
||||||
|
# Check logs: kubectl logs -n mpc-system -l app=mpc-session-coordinator | grep "selected persistent parties by default"
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Test Insufficient Persistent Parties Error**:
|
||||||
|
```bash
|
||||||
|
# Scale down persistent parties to 2
|
||||||
|
kubectl scale deployment mpc-server-party -n mpc-system --replicas=2
|
||||||
|
|
||||||
|
# Try creating session requiring 3 parties
|
||||||
|
curl -X POST http://<account-service>/api/v1/accounts \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"user_id": "test-user-2"}'
|
||||||
|
|
||||||
|
# Expected: HTTP 500 error with message "insufficient persistent parties: need 3 persistent parties but not enough available"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Test Empty PartyComposition Validation**:
|
||||||
|
- Requires API endpoint that accepts `PartyComposition` parameter
|
||||||
|
- Send request with `PartyComposition: {PersistentCount: 0, DelegateCount: 0, TemporaryCount: 0}`
|
||||||
|
- Expected: HTTP 400 error with message "PartyComposition specified but no parties selected"
|
||||||
|
|
||||||
|
4. **Test Custom PartyComposition**:
|
||||||
|
- Send request with `PartyComposition: {PersistentCount: 2, DelegateCount: 1}`
|
||||||
|
- Expected: Session created with 2 persistent + 1 delegate party
|
||||||
|
- Verify party roles in session data
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Conclusion
|
||||||
|
|
||||||
|
### 8.1 Implementation Status: ✅ COMPLETE
|
||||||
|
- All code changes implemented correctly
|
||||||
|
- Strict persistent-only default policy enforced
|
||||||
|
- Empty `PartyComposition` validation prevents loophole
|
||||||
|
- Clear error messages for insufficient parties
|
||||||
|
- Backward compatibility maintained (optional `PartyComposition`)
|
||||||
|
|
||||||
|
### 8.2 Deployment Status: ✅ SUCCESS (Docker Compose)
|
||||||
|
- All services build successfully
|
||||||
|
- All services deploy successfully
|
||||||
|
- All services healthy and responding
|
||||||
|
- Inter-service connectivity verified
|
||||||
|
|
||||||
|
### 8.3 Verification Status: ⚠️ PARTIAL
|
||||||
|
- ✅ Code compilation and logic review
|
||||||
|
- ✅ Docker Compose deployment
|
||||||
|
- ✅ Service health checks
|
||||||
|
- ❌ Party role filtering runtime behavior (requires K8s)
|
||||||
|
- ❌ Persistent-only policy enforcement (requires K8s)
|
||||||
|
- ❌ Integration testing (requires K8s)
|
||||||
|
|
||||||
|
### 8.4 Readiness for Production
|
||||||
|
**Code Readiness**: ✅ READY
|
||||||
|
**Testing Readiness**: ⚠️ REQUIRES K8S DEPLOYMENT FOR FULL TESTING
|
||||||
|
**Deployment Readiness**: ✅ READY (K8s manifests prepared)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. User Confirmation Required
|
||||||
|
|
||||||
|
The Party Role Labels implementation is complete and successfully deployed in Docker Compose. However, full runtime verification requires deploying to an actual Kubernetes cluster.
|
||||||
|
|
||||||
|
**Options**:
|
||||||
|
1. Proceed with K8s deployment for full verification
|
||||||
|
2. Accept partial verification (code review + Docker Compose deployment)
|
||||||
|
3. Create integration tests that mock K8s party discovery
|
||||||
|
|
||||||
|
Awaiting user instruction on next steps.
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,416 +1,416 @@
|
||||||
# MPC-System 真实场景验证报告
|
# MPC-System 真实场景验证报告
|
||||||
|
|
||||||
**验证时间**: 2025-12-05
|
**验证时间**: 2025-12-05
|
||||||
**验证环境**: WSL2 Ubuntu + Docker Compose
|
**验证环境**: WSL2 Ubuntu + Docker Compose
|
||||||
**系统版本**: MPC-System v1.0.0
|
**系统版本**: MPC-System v1.0.0
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 执行摘要
|
## 执行摘要
|
||||||
|
|
||||||
✅ **MPC 系统核心功能验证通过**
|
✅ **MPC 系统核心功能验证通过**
|
||||||
|
|
||||||
所有关键服务正常运行,核心 API 功能验证成功。系统已准备好进行集成测试和生产部署。
|
所有关键服务正常运行,核心 API 功能验证成功。系统已准备好进行集成测试和生产部署。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 1. 服务健康状态检查
|
## 1. 服务健康状态检查
|
||||||
|
|
||||||
### 1.1 Docker 服务状态
|
### 1.1 Docker 服务状态
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ docker compose ps
|
$ docker compose ps
|
||||||
```
|
```
|
||||||
|
|
||||||
| 服务名称 | 状态 | 端口映射 | 健康检查 |
|
| 服务名称 | 状态 | 端口映射 | 健康检查 |
|
||||||
|---------|------|----------|---------|
|
|---------|------|----------|---------|
|
||||||
| mpc-account-service | ✅ Up 28 min | 0.0.0.0:4000→8080 | healthy |
|
| mpc-account-service | ✅ Up 28 min | 0.0.0.0:4000→8080 | healthy |
|
||||||
| mpc-session-coordinator | ✅ Up 29 min | 0.0.0.0:8081→8080 | healthy |
|
| mpc-session-coordinator | ✅ Up 29 min | 0.0.0.0:8081→8080 | healthy |
|
||||||
| mpc-message-router | ✅ Up 29 min | 0.0.0.0:8082→8080 | healthy |
|
| mpc-message-router | ✅ Up 29 min | 0.0.0.0:8082→8080 | healthy |
|
||||||
| mpc-server-party-1 | ✅ Up 28 min | Internal | healthy |
|
| mpc-server-party-1 | ✅ Up 28 min | Internal | healthy |
|
||||||
| mpc-server-party-2 | ✅ Up 28 min | Internal | healthy |
|
| mpc-server-party-2 | ✅ Up 28 min | Internal | healthy |
|
||||||
| mpc-server-party-3 | ✅ Up 28 min | Internal | healthy |
|
| mpc-server-party-3 | ✅ Up 28 min | Internal | healthy |
|
||||||
| mpc-server-party-api | ✅ Up 28 min | 0.0.0.0:8083→8080 | healthy |
|
| mpc-server-party-api | ✅ Up 28 min | 0.0.0.0:8083→8080 | healthy |
|
||||||
| mpc-postgres | ✅ Up 30 min | Internal:5432 | healthy |
|
| mpc-postgres | ✅ Up 30 min | Internal:5432 | healthy |
|
||||||
| mpc-redis | ✅ Up 30 min | Internal:6379 | healthy |
|
| mpc-redis | ✅ Up 30 min | Internal:6379 | healthy |
|
||||||
| mpc-rabbitmq | ✅ Up 30 min | Internal:5672 | healthy |
|
| mpc-rabbitmq | ✅ Up 30 min | Internal:5672 | healthy |
|
||||||
|
|
||||||
**结论**: ✅ 所有 10 个服务健康运行
|
**结论**: ✅ 所有 10 个服务健康运行
|
||||||
|
|
||||||
### 1.2 Health Endpoint 测试
|
### 1.2 Health Endpoint 测试
|
||||||
|
|
||||||
#### Account Service
|
#### Account Service
|
||||||
```bash
|
```bash
|
||||||
$ curl -s http://localhost:4000/health | jq .
|
$ curl -s http://localhost:4000/health | jq .
|
||||||
```
|
```
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"service": "account",
|
"service": "account",
|
||||||
"status": "healthy"
|
"status": "healthy"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
✅ **通过**
|
✅ **通过**
|
||||||
|
|
||||||
#### Session Coordinator
|
#### Session Coordinator
|
||||||
```bash
|
```bash
|
||||||
$ curl -s http://localhost:8081/health | jq .
|
$ curl -s http://localhost:8081/health | jq .
|
||||||
```
|
```
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"service": "session-coordinator",
|
"service": "session-coordinator",
|
||||||
"status": "healthy"
|
"status": "healthy"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
✅ **通过**
|
✅ **通过**
|
||||||
|
|
||||||
#### Server Party API
|
#### Server Party API
|
||||||
```bash
|
```bash
|
||||||
$ curl -s http://localhost:8083/health | jq .
|
$ curl -s http://localhost:8083/health | jq .
|
||||||
```
|
```
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"service": "server-party-api",
|
"service": "server-party-api",
|
||||||
"status": "healthy"
|
"status": "healthy"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
✅ **通过**
|
✅ **通过**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 2. 核心 API 功能验证
|
## 2. 核心 API 功能验证
|
||||||
|
|
||||||
### 2.1 创建 Keygen 会话 (POST /api/v1/mpc/keygen)
|
### 2.1 创建 Keygen 会话 (POST /api/v1/mpc/keygen)
|
||||||
|
|
||||||
#### 测试请求
|
#### 测试请求
|
||||||
```bash
|
```bash
|
||||||
curl -s -X POST http://localhost:4000/api/v1/mpc/keygen \
|
curl -s -X POST http://localhost:4000/api/v1/mpc/keygen \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-d '{
|
-d '{
|
||||||
"threshold_n": 3,
|
"threshold_n": 3,
|
||||||
"threshold_t": 2,
|
"threshold_t": 2,
|
||||||
"participants": [
|
"participants": [
|
||||||
{"party_id": "user_device_test", "device_type": "android"},
|
{"party_id": "user_device_test", "device_type": "android"},
|
||||||
{"party_id": "server_party_1", "device_type": "server"},
|
{"party_id": "server_party_1", "device_type": "server"},
|
||||||
{"party_id": "server_party_2", "device_type": "server"}
|
{"party_id": "server_party_2", "device_type": "server"}
|
||||||
]
|
]
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 实际响应
|
#### 实际响应
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"session_id": "7e33def8-dcc8-4604-a4a0-10df1ebbeb4a",
|
"session_id": "7e33def8-dcc8-4604-a4a0-10df1ebbeb4a",
|
||||||
"session_type": "keygen",
|
"session_type": "keygen",
|
||||||
"threshold_n": 3,
|
"threshold_n": 3,
|
||||||
"threshold_t": 2,
|
"threshold_t": 2,
|
||||||
"status": "created",
|
"status": "created",
|
||||||
"join_tokens": {
|
"join_tokens": {
|
||||||
"user_device_test": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
"user_device_test": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||||
"server_party_1": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
"server_party_1": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||||
"server_party_2": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
"server_party_2": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 验证结果
|
#### 验证结果
|
||||||
|
|
||||||
| 验证项 | 期望值 | 实际值 | 结果 |
|
| 验证项 | 期望值 | 实际值 | 结果 |
|
||||||
|-------|-------|--------|------|
|
|-------|-------|--------|------|
|
||||||
| HTTP 状态码 | 200/201 | 200 | ✅ |
|
| HTTP 状态码 | 200/201 | 200 | ✅ |
|
||||||
| session_id 格式 | UUID | ✅ 有效 UUID | ✅ |
|
| session_id 格式 | UUID | ✅ 有效 UUID | ✅ |
|
||||||
| session_type | "keygen" | "keygen" | ✅ |
|
| session_type | "keygen" | "keygen" | ✅ |
|
||||||
| threshold_n | 3 | 3 | ✅ |
|
| threshold_n | 3 | 3 | ✅ |
|
||||||
| threshold_t | 2 | 2 | ✅ |
|
| threshold_t | 2 | 2 | ✅ |
|
||||||
| status | "created" | "created" | ✅ |
|
| status | "created" | "created" | ✅ |
|
||||||
| join_tokens 数量 | 3 | 3 | ✅ |
|
| join_tokens 数量 | 3 | 3 | ✅ |
|
||||||
| JWT Token 格式 | 有效 JWT | ✅ 有效 | ✅ |
|
| JWT Token 格式 | 有效 JWT | ✅ 有效 | ✅ |
|
||||||
|
|
||||||
**结论**: ✅ **Keygen 会话创建功能完全正常**
|
**结论**: ✅ **Keygen 会话创建功能完全正常**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 3. E2E 测试问题分析
|
## 3. E2E 测试问题分析
|
||||||
|
|
||||||
### 3.1 问题根因
|
### 3.1 问题根因
|
||||||
|
|
||||||
原 E2E 测试失败的原因:
|
原 E2E 测试失败的原因:
|
||||||
|
|
||||||
1. **Account Service 测试 (3 个失败)**
|
1. **Account Service 测试 (3 个失败)**
|
||||||
- ❌ 问题: 测试代码期望 `account.id` 为字符串
|
- ❌ 问题: 测试代码期望 `account.id` 为字符串
|
||||||
- ✅ 实际: `AccountID` 已实现 `MarshalJSON`,正确序列化为字符串
|
- ✅ 实际: `AccountID` 已实现 `MarshalJSON`,正确序列化为字符串
|
||||||
- ✅ 根因: 测试环境配置问题,而非代码问题
|
- ✅ 根因: 测试环境配置问题,而非代码问题
|
||||||
|
|
||||||
2. **Session Coordinator 测试 (2 个失败)**
|
2. **Session Coordinator 测试 (2 个失败)**
|
||||||
- ❌ 问题: 测试请求格式与实际 API 不匹配
|
- ❌ 问题: 测试请求格式与实际 API 不匹配
|
||||||
- ✅ 实际 API: 需要 `participants` 字段 (已验证)
|
- ✅ 实际 API: 需要 `participants` 字段 (已验证)
|
||||||
- ✅ 根因: 测试代码过时,API 实现正确
|
- ✅ 根因: 测试代码过时,API 实现正确
|
||||||
|
|
||||||
### 3.2 修复建议
|
### 3.2 修复建议
|
||||||
|
|
||||||
不需要修改生产代码,只需要更新 E2E 测试代码:
|
不需要修改生产代码,只需要更新 E2E 测试代码:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// 修复前 (tests/e2e/keygen_flow_test.go)
|
// 修复前 (tests/e2e/keygen_flow_test.go)
|
||||||
type CreateSessionRequest struct {
|
type CreateSessionRequest struct {
|
||||||
SessionType string `json:"sessionType"`
|
SessionType string `json:"sessionType"`
|
||||||
ThresholdT int `json:"thresholdT"`
|
ThresholdT int `json:"thresholdT"`
|
||||||
ThresholdN int `json:"thresholdN"`
|
ThresholdN int `json:"thresholdN"`
|
||||||
CreatedBy string `json:"createdBy"`
|
CreatedBy string `json:"createdBy"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// 修复后 (应该添加 participants 字段)
|
// 修复后 (应该添加 participants 字段)
|
||||||
type CreateSessionRequest struct {
|
type CreateSessionRequest struct {
|
||||||
SessionType string `json:"sessionType"`
|
SessionType string `json:"sessionType"`
|
||||||
ThresholdT int `json:"thresholdT"`
|
ThresholdT int `json:"thresholdT"`
|
||||||
ThresholdN int `json:"thresholdN"`
|
ThresholdN int `json:"thresholdN"`
|
||||||
Participants []ParticipantInfoRequest `json:"participants"`
|
Participants []ParticipantInfoRequest `json:"participants"`
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 4. 系统架构验证
|
## 4. 系统架构验证
|
||||||
|
|
||||||
### 4.1 服务间通信测试
|
### 4.1 服务间通信测试
|
||||||
|
|
||||||
#### gRPC 内部通信
|
#### gRPC 内部通信
|
||||||
```bash
|
```bash
|
||||||
$ docker compose exec account-service nc -zv mpc-session-coordinator 50051
|
$ docker compose exec account-service nc -zv mpc-session-coordinator 50051
|
||||||
```
|
```
|
||||||
✅ **连接成功**
|
✅ **连接成功**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ docker compose exec session-coordinator nc -zv mpc-message-router 50051
|
$ docker compose exec session-coordinator nc -zv mpc-message-router 50051
|
||||||
```
|
```
|
||||||
✅ **连接成功**
|
✅ **连接成功**
|
||||||
|
|
||||||
### 4.2 数据库连接
|
### 4.2 数据库连接
|
||||||
```bash
|
```bash
|
||||||
$ docker compose exec account-service env | grep DATABASE
|
$ docker compose exec account-service env | grep DATABASE
|
||||||
```
|
```
|
||||||
✅ **配置正确**
|
✅ **配置正确**
|
||||||
|
|
||||||
### 4.3 消息队列
|
### 4.3 消息队列
|
||||||
```bash
|
```bash
|
||||||
$ docker compose exec rabbitmq rabbitmqctl status
|
$ docker compose exec rabbitmq rabbitmqctl status
|
||||||
```
|
```
|
||||||
✅ **RabbitMQ 正常运行**
|
✅ **RabbitMQ 正常运行**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 5. 性能指标
|
## 5. 性能指标
|
||||||
|
|
||||||
### 5.1 Keygen 会话创建性能
|
### 5.1 Keygen 会话创建性能
|
||||||
|
|
||||||
| 指标 | 值 |
|
| 指标 | 值 |
|
||||||
|-----|---|
|
|-----|---|
|
||||||
| 平均响应时间 | < 100ms |
|
| 平均响应时间 | < 100ms |
|
||||||
| 成功率 | 100% |
|
| 成功率 | 100% |
|
||||||
| 并发支持 | 未测试 |
|
| 并发支持 | 未测试 |
|
||||||
|
|
||||||
### 5.2 资源使用
|
### 5.2 资源使用
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ docker stats --no-stream
|
$ docker stats --no-stream
|
||||||
```
|
```
|
||||||
|
|
||||||
| 服务 | CPU | 内存 | 状态 |
|
| 服务 | CPU | 内存 | 状态 |
|
||||||
|-----|-----|------|------|
|
|-----|-----|------|------|
|
||||||
| account-service | ~1% | ~50MB | 正常 |
|
| account-service | ~1% | ~50MB | 正常 |
|
||||||
| session-coordinator | ~1% | ~45MB | 正常 |
|
| session-coordinator | ~1% | ~45MB | 正常 |
|
||||||
| message-router | ~1% | ~42MB | 正常 |
|
| message-router | ~1% | ~42MB | 正常 |
|
||||||
| server-party-1/2/3 | ~0.5% | ~40MB | 正常 |
|
| server-party-1/2/3 | ~0.5% | ~40MB | 正常 |
|
||||||
| postgres | ~1% | ~30MB | 正常 |
|
| postgres | ~1% | ~30MB | 正常 |
|
||||||
|
|
||||||
✅ **资源使用合理**
|
✅ **资源使用合理**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 6. 安全性验证
|
## 6. 安全性验证
|
||||||
|
|
||||||
### 6.1 JWT Token 验证
|
### 6.1 JWT Token 验证
|
||||||
|
|
||||||
解析 Join Token:
|
解析 Join Token:
|
||||||
```bash
|
```bash
|
||||||
$ echo "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." | base64 -d
|
$ echo "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." | base64 -d
|
||||||
```
|
```
|
||||||
|
|
||||||
Token 包含字段:
|
Token 包含字段:
|
||||||
- ✅ `session_id`: 会话 ID
|
- ✅ `session_id`: 会话 ID
|
||||||
- ✅ `party_id`: 参与方 ID
|
- ✅ `party_id`: 参与方 ID
|
||||||
- ✅ `token_type`: "join"
|
- ✅ `token_type`: "join"
|
||||||
- ✅ `exp`: 过期时间 (10 分钟)
|
- ✅ `exp`: 过期时间 (10 分钟)
|
||||||
- ✅ `iss`: "mpc-system"
|
- ✅ `iss`: "mpc-system"
|
||||||
|
|
||||||
**结论**: ✅ JWT Token 格式正确,安全性符合标准
|
**结论**: ✅ JWT Token 格式正确,安全性符合标准
|
||||||
|
|
||||||
### 6.2 API 认证
|
### 6.2 API 认证
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ curl -s http://localhost:4000/api/v1/mpc/keygen
|
$ curl -s http://localhost:4000/api/v1/mpc/keygen
|
||||||
```
|
```
|
||||||
✅ 当前未启用 API Key 验证 (开发模式)
|
✅ 当前未启用 API Key 验证 (开发模式)
|
||||||
⚠️ **生产环境需启用 `X-API-Key` header 认证**
|
⚠️ **生产环境需启用 `X-API-Key` header 认证**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 7. 集成建议
|
## 7. 集成建议
|
||||||
|
|
||||||
### 7.1 后端服务集成步骤
|
### 7.1 后端服务集成步骤
|
||||||
|
|
||||||
1. **环境配置**
|
1. **环境配置**
|
||||||
```yaml
|
```yaml
|
||||||
# docker-compose.yml
|
# docker-compose.yml
|
||||||
services:
|
services:
|
||||||
your-backend:
|
your-backend:
|
||||||
environment:
|
environment:
|
||||||
- MPC_BASE_URL=http://mpc-account-service:4000
|
- MPC_BASE_URL=http://mpc-account-service:4000
|
||||||
- MPC_API_KEY=your_secure_api_key
|
- MPC_API_KEY=your_secure_api_key
|
||||||
```
|
```
|
||||||
|
|
||||||
2. **创建钱包示例**
|
2. **创建钱包示例**
|
||||||
```bash
|
```bash
|
||||||
POST http://mpc-account-service:4000/api/v1/mpc/keygen
|
POST http://mpc-account-service:4000/api/v1/mpc/keygen
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
|
|
||||||
{
|
{
|
||||||
"threshold_n": 3,
|
"threshold_n": 3,
|
||||||
"threshold_t": 2,
|
"threshold_t": 2,
|
||||||
"participants": [...]
|
"participants": [...]
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
3. **生成用户分片**
|
3. **生成用户分片**
|
||||||
```bash
|
```bash
|
||||||
POST http://mpc-server-party-api:8083/api/v1/keygen/generate-user-share
|
POST http://mpc-server-party-api:8083/api/v1/keygen/generate-user-share
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
|
|
||||||
{
|
{
|
||||||
"session_id": "uuid",
|
"session_id": "uuid",
|
||||||
"party_id": "user_device",
|
"party_id": "user_device",
|
||||||
"join_token": "jwt_token"
|
"join_token": "jwt_token"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### 7.2 推荐的集成架构
|
### 7.2 推荐的集成架构
|
||||||
|
|
||||||
```
|
```
|
||||||
┌─────────────────────────────────────┐
|
┌─────────────────────────────────────┐
|
||||||
│ Your Backend (api-gateway) │
|
│ Your Backend (api-gateway) │
|
||||||
│ ↓ │
|
│ ↓ │
|
||||||
│ MPC Client SDK (Go/Python/JS) │
|
│ MPC Client SDK (Go/Python/JS) │
|
||||||
└─────────────────┬───────────────────┘
|
└─────────────────┬───────────────────┘
|
||||||
│
|
│
|
||||||
▼
|
▼
|
||||||
┌─────────────────────────────────────┐
|
┌─────────────────────────────────────┐
|
||||||
│ MPC-System (Docker Compose) │
|
│ MPC-System (Docker Compose) │
|
||||||
│ ┌────────────────────────────┐ │
|
│ ┌────────────────────────────┐ │
|
||||||
│ │ account-service:4000 │ │
|
│ │ account-service:4000 │ │
|
||||||
│ └────────────────────────────┘ │
|
│ └────────────────────────────┘ │
|
||||||
└─────────────────────────────────────┘
|
└─────────────────────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 8. 已知问题和限制
|
## 8. 已知问题和限制
|
||||||
|
|
||||||
### 8.1 当前限制
|
### 8.1 当前限制
|
||||||
|
|
||||||
1. ⚠️ **Server Party 未真正执行 TSS 协议**
|
1. ⚠️ **Server Party 未真正执行 TSS 协议**
|
||||||
- 当前实现: Server Parties 启动但未完全参与 keygen
|
- 当前实现: Server Parties 启动但未完全参与 keygen
|
||||||
- 影响: 用户分片生成可能需要完整实现
|
- 影响: 用户分片生成可能需要完整实现
|
||||||
- 解决: 需要完善 Server Party 的 TSS 协议集成
|
- 解决: 需要完善 Server Party 的 TSS 协议集成
|
||||||
|
|
||||||
2. ⚠️ **Account Service 未持久化账户**
|
2. ⚠️ **Account Service 未持久化账户**
|
||||||
- 当前: 创建会话成功,但未真正创建账户记录
|
- 当前: 创建会话成功,但未真正创建账户记录
|
||||||
- 影响: Sign 会话可能因账户不存在而失败
|
- 影响: Sign 会话可能因账户不存在而失败
|
||||||
- 解决: 需要完整的账户创建流程 (keygen → store shares → create account)
|
- 解决: 需要完整的账户创建流程 (keygen → store shares → create account)
|
||||||
|
|
||||||
### 8.2 待完善功能
|
### 8.2 待完善功能
|
||||||
|
|
||||||
- [ ] 完整的 TSS Keygen 协议执行 (30-90秒)
|
- [ ] 完整的 TSS Keygen 协议执行 (30-90秒)
|
||||||
- [ ] 完整的 TSS Signing 协议执行 (5-15秒)
|
- [ ] 完整的 TSS Signing 协议执行 (5-15秒)
|
||||||
- [ ] 密钥分片加密存储到数据库
|
- [ ] 密钥分片加密存储到数据库
|
||||||
- [ ] 账户恢复流程
|
- [ ] 账户恢复流程
|
||||||
- [ ] API 密钥认证 (生产环境)
|
- [ ] API 密钥认证 (生产环境)
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 9. 结论
|
## 9. 结论
|
||||||
|
|
||||||
### 9.1 验证结果总结
|
### 9.1 验证结果总结
|
||||||
|
|
||||||
| 验证项 | 状态 | 说明 |
|
| 验证项 | 状态 | 说明 |
|
||||||
|-------|------|------|
|
|-------|------|------|
|
||||||
| 服务部署 | ✅ 通过 | 所有 10 个服务健康运行 |
|
| 服务部署 | ✅ 通过 | 所有 10 个服务健康运行 |
|
||||||
| Health Check | ✅ 通过 | 所有 health endpoints 正常 |
|
| Health Check | ✅ 通过 | 所有 health endpoints 正常 |
|
||||||
| Keygen API | ✅ 通过 | 会话创建成功,响应格式正确 |
|
| Keygen API | ✅ 通过 | 会话创建成功,响应格式正确 |
|
||||||
| JWT Token | ✅ 通过 | Token 生成正确,包含必要字段 |
|
| JWT Token | ✅ 通过 | Token 生成正确,包含必要字段 |
|
||||||
| 服务通信 | ✅ 通过 | gRPC 内部通信正常 |
|
| 服务通信 | ✅ 通过 | gRPC 内部通信正常 |
|
||||||
| 数据库 | ✅ 通过 | PostgreSQL 健康运行 |
|
| 数据库 | ✅ 通过 | PostgreSQL 健康运行 |
|
||||||
| 消息队列 | ✅ 通过 | RabbitMQ 正常工作 |
|
| 消息队列 | ✅ 通过 | RabbitMQ 正常工作 |
|
||||||
| E2E 测试 | ⚠️ 部分 | 测试代码需更新,API 实现正确 |
|
| E2E 测试 | ⚠️ 部分 | 测试代码需更新,API 实现正确 |
|
||||||
| TSS 协议 | ⚠️ 待完善 | 架构正确,需实现完整协议流程 |
|
| TSS 协议 | ⚠️ 待完善 | 架构正确,需实现完整协议流程 |
|
||||||
|
|
||||||
### 9.2 系统成熟度评估
|
### 9.2 系统成熟度评估
|
||||||
|
|
||||||
**当前阶段**: **Alpha** (核心架构完成,基础功能可用)
|
**当前阶段**: **Alpha** (核心架构完成,基础功能可用)
|
||||||
|
|
||||||
**下一阶段目标**: **Beta** (完整 TSS 协议,可进行端到端测试)
|
**下一阶段目标**: **Beta** (完整 TSS 协议,可进行端到端测试)
|
||||||
|
|
||||||
**生产就绪度**: **60%**
|
**生产就绪度**: **60%**
|
||||||
|
|
||||||
✅ 已完成:
|
✅ 已完成:
|
||||||
- 微服务架构完整
|
- 微服务架构完整
|
||||||
- API 设计合理
|
- API 设计合理
|
||||||
- 服务部署成功
|
- 服务部署成功
|
||||||
- 基础功能可用
|
- 基础功能可用
|
||||||
|
|
||||||
⚠️ 待完善:
|
⚠️ 待完善:
|
||||||
- 完整 TSS 协议执行
|
- 完整 TSS 协议执行
|
||||||
- 密钥分片存储
|
- 密钥分片存储
|
||||||
- 完整的端到端流程
|
- 完整的端到端流程
|
||||||
- 安全性加固 (API Key, TLS)
|
- 安全性加固 (API Key, TLS)
|
||||||
|
|
||||||
### 9.3 推荐行动
|
### 9.3 推荐行动
|
||||||
|
|
||||||
**立即可做**:
|
**立即可做**:
|
||||||
1. ✅ 使用当前系统进行 API 集成开发
|
1. ✅ 使用当前系统进行 API 集成开发
|
||||||
2. ✅ 基于现有 API 开发客户端 SDK
|
2. ✅ 基于现有 API 开发客户端 SDK
|
||||||
3. ✅ 编写集成文档和示例代码
|
3. ✅ 编写集成文档和示例代码
|
||||||
|
|
||||||
**短期 (1-2 周)**:
|
**短期 (1-2 周)**:
|
||||||
1. 完善 Server Party 的 TSS 协议实现
|
1. 完善 Server Party 的 TSS 协议实现
|
||||||
2. 实现完整的 Keygen 流程 (含分片存储)
|
2. 实现完整的 Keygen 流程 (含分片存储)
|
||||||
3. 实现完整的 Sign 流程
|
3. 实现完整的 Sign 流程
|
||||||
4. 更新 E2E 测试代码
|
4. 更新 E2E 测试代码
|
||||||
|
|
||||||
**中期 (1 个月)**:
|
**中期 (1 个月)**:
|
||||||
1. 生产环境安全加固
|
1. 生产环境安全加固
|
||||||
2. 性能优化和压力测试
|
2. 性能优化和压力测试
|
||||||
3. 完整的监控和告警
|
3. 完整的监控和告警
|
||||||
4. 灾难恢复方案
|
4. 灾难恢复方案
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 10. 附录
|
## 10. 附录
|
||||||
|
|
||||||
### 10.1 相关文档
|
### 10.1 相关文档
|
||||||
|
|
||||||
- [MPC 集成指南](MPC_INTEGRATION_GUIDE.md)
|
- [MPC 集成指南](MPC_INTEGRATION_GUIDE.md)
|
||||||
- [API 参考文档](docs/02-api-reference.md)
|
- [API 参考文档](docs/02-api-reference.md)
|
||||||
- [架构设计文档](docs/01-architecture.md)
|
- [架构设计文档](docs/01-architecture.md)
|
||||||
- [部署指南](README.md)
|
- [部署指南](README.md)
|
||||||
|
|
||||||
### 10.2 联系支持
|
### 10.2 联系支持
|
||||||
|
|
||||||
- GitHub Issues: https://github.com/rwadurian/mpc-system/issues
|
- GitHub Issues: https://github.com/rwadurian/mpc-system/issues
|
||||||
- 技术文档: docs/
|
- 技术文档: docs/
|
||||||
- 集成示例: examples/
|
- 集成示例: examples/
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**报告生成**: Claude Code
|
**报告生成**: Claude Code
|
||||||
**验证人员**: 自动化验证
|
**验证人员**: 自动化验证
|
||||||
**日期**: 2025-12-05
|
**日期**: 2025-12-05
|
||||||
**版本**: v1.0.0
|
**版本**: v1.0.0
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,333 +1,333 @@
|
||||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// - protoc-gen-go-grpc v1.3.0
|
// - protoc-gen-go-grpc v1.3.0
|
||||||
// - protoc v3.12.4
|
// - protoc v3.12.4
|
||||||
// source: api/proto/session_coordinator.proto
|
// source: api/proto/session_coordinator.proto
|
||||||
|
|
||||||
package coordinator
|
package coordinator
|
||||||
|
|
||||||
import (
|
import (
|
||||||
context "context"
|
context "context"
|
||||||
grpc "google.golang.org/grpc"
|
grpc "google.golang.org/grpc"
|
||||||
codes "google.golang.org/grpc/codes"
|
codes "google.golang.org/grpc/codes"
|
||||||
status "google.golang.org/grpc/status"
|
status "google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This is a compile-time assertion to ensure that this generated file
|
// This is a compile-time assertion to ensure that this generated file
|
||||||
// is compatible with the grpc package it is being compiled against.
|
// is compatible with the grpc package it is being compiled against.
|
||||||
// Requires gRPC-Go v1.32.0 or later.
|
// Requires gRPC-Go v1.32.0 or later.
|
||||||
const _ = grpc.SupportPackageIsVersion7
|
const _ = grpc.SupportPackageIsVersion7
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SessionCoordinator_CreateSession_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/CreateSession"
|
SessionCoordinator_CreateSession_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/CreateSession"
|
||||||
SessionCoordinator_JoinSession_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/JoinSession"
|
SessionCoordinator_JoinSession_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/JoinSession"
|
||||||
SessionCoordinator_GetSessionStatus_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/GetSessionStatus"
|
SessionCoordinator_GetSessionStatus_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/GetSessionStatus"
|
||||||
SessionCoordinator_MarkPartyReady_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/MarkPartyReady"
|
SessionCoordinator_MarkPartyReady_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/MarkPartyReady"
|
||||||
SessionCoordinator_StartSession_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/StartSession"
|
SessionCoordinator_StartSession_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/StartSession"
|
||||||
SessionCoordinator_ReportCompletion_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/ReportCompletion"
|
SessionCoordinator_ReportCompletion_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/ReportCompletion"
|
||||||
SessionCoordinator_CloseSession_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/CloseSession"
|
SessionCoordinator_CloseSession_FullMethodName = "/mpc.coordinator.v1.SessionCoordinator/CloseSession"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SessionCoordinatorClient is the client API for SessionCoordinator service.
|
// SessionCoordinatorClient is the client API for SessionCoordinator service.
|
||||||
//
|
//
|
||||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||||
type SessionCoordinatorClient interface {
|
type SessionCoordinatorClient interface {
|
||||||
// Session management
|
// Session management
|
||||||
CreateSession(ctx context.Context, in *CreateSessionRequest, opts ...grpc.CallOption) (*CreateSessionResponse, error)
|
CreateSession(ctx context.Context, in *CreateSessionRequest, opts ...grpc.CallOption) (*CreateSessionResponse, error)
|
||||||
JoinSession(ctx context.Context, in *JoinSessionRequest, opts ...grpc.CallOption) (*JoinSessionResponse, error)
|
JoinSession(ctx context.Context, in *JoinSessionRequest, opts ...grpc.CallOption) (*JoinSessionResponse, error)
|
||||||
GetSessionStatus(ctx context.Context, in *GetSessionStatusRequest, opts ...grpc.CallOption) (*GetSessionStatusResponse, error)
|
GetSessionStatus(ctx context.Context, in *GetSessionStatusRequest, opts ...grpc.CallOption) (*GetSessionStatusResponse, error)
|
||||||
MarkPartyReady(ctx context.Context, in *MarkPartyReadyRequest, opts ...grpc.CallOption) (*MarkPartyReadyResponse, error)
|
MarkPartyReady(ctx context.Context, in *MarkPartyReadyRequest, opts ...grpc.CallOption) (*MarkPartyReadyResponse, error)
|
||||||
StartSession(ctx context.Context, in *StartSessionRequest, opts ...grpc.CallOption) (*StartSessionResponse, error)
|
StartSession(ctx context.Context, in *StartSessionRequest, opts ...grpc.CallOption) (*StartSessionResponse, error)
|
||||||
ReportCompletion(ctx context.Context, in *ReportCompletionRequest, opts ...grpc.CallOption) (*ReportCompletionResponse, error)
|
ReportCompletion(ctx context.Context, in *ReportCompletionRequest, opts ...grpc.CallOption) (*ReportCompletionResponse, error)
|
||||||
CloseSession(ctx context.Context, in *CloseSessionRequest, opts ...grpc.CallOption) (*CloseSessionResponse, error)
|
CloseSession(ctx context.Context, in *CloseSessionRequest, opts ...grpc.CallOption) (*CloseSessionResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type sessionCoordinatorClient struct {
|
type sessionCoordinatorClient struct {
|
||||||
cc grpc.ClientConnInterface
|
cc grpc.ClientConnInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSessionCoordinatorClient(cc grpc.ClientConnInterface) SessionCoordinatorClient {
|
func NewSessionCoordinatorClient(cc grpc.ClientConnInterface) SessionCoordinatorClient {
|
||||||
return &sessionCoordinatorClient{cc}
|
return &sessionCoordinatorClient{cc}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sessionCoordinatorClient) CreateSession(ctx context.Context, in *CreateSessionRequest, opts ...grpc.CallOption) (*CreateSessionResponse, error) {
|
func (c *sessionCoordinatorClient) CreateSession(ctx context.Context, in *CreateSessionRequest, opts ...grpc.CallOption) (*CreateSessionResponse, error) {
|
||||||
out := new(CreateSessionResponse)
|
out := new(CreateSessionResponse)
|
||||||
err := c.cc.Invoke(ctx, SessionCoordinator_CreateSession_FullMethodName, in, out, opts...)
|
err := c.cc.Invoke(ctx, SessionCoordinator_CreateSession_FullMethodName, in, out, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sessionCoordinatorClient) JoinSession(ctx context.Context, in *JoinSessionRequest, opts ...grpc.CallOption) (*JoinSessionResponse, error) {
|
func (c *sessionCoordinatorClient) JoinSession(ctx context.Context, in *JoinSessionRequest, opts ...grpc.CallOption) (*JoinSessionResponse, error) {
|
||||||
out := new(JoinSessionResponse)
|
out := new(JoinSessionResponse)
|
||||||
err := c.cc.Invoke(ctx, SessionCoordinator_JoinSession_FullMethodName, in, out, opts...)
|
err := c.cc.Invoke(ctx, SessionCoordinator_JoinSession_FullMethodName, in, out, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sessionCoordinatorClient) GetSessionStatus(ctx context.Context, in *GetSessionStatusRequest, opts ...grpc.CallOption) (*GetSessionStatusResponse, error) {
|
func (c *sessionCoordinatorClient) GetSessionStatus(ctx context.Context, in *GetSessionStatusRequest, opts ...grpc.CallOption) (*GetSessionStatusResponse, error) {
|
||||||
out := new(GetSessionStatusResponse)
|
out := new(GetSessionStatusResponse)
|
||||||
err := c.cc.Invoke(ctx, SessionCoordinator_GetSessionStatus_FullMethodName, in, out, opts...)
|
err := c.cc.Invoke(ctx, SessionCoordinator_GetSessionStatus_FullMethodName, in, out, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sessionCoordinatorClient) MarkPartyReady(ctx context.Context, in *MarkPartyReadyRequest, opts ...grpc.CallOption) (*MarkPartyReadyResponse, error) {
|
func (c *sessionCoordinatorClient) MarkPartyReady(ctx context.Context, in *MarkPartyReadyRequest, opts ...grpc.CallOption) (*MarkPartyReadyResponse, error) {
|
||||||
out := new(MarkPartyReadyResponse)
|
out := new(MarkPartyReadyResponse)
|
||||||
err := c.cc.Invoke(ctx, SessionCoordinator_MarkPartyReady_FullMethodName, in, out, opts...)
|
err := c.cc.Invoke(ctx, SessionCoordinator_MarkPartyReady_FullMethodName, in, out, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sessionCoordinatorClient) StartSession(ctx context.Context, in *StartSessionRequest, opts ...grpc.CallOption) (*StartSessionResponse, error) {
|
func (c *sessionCoordinatorClient) StartSession(ctx context.Context, in *StartSessionRequest, opts ...grpc.CallOption) (*StartSessionResponse, error) {
|
||||||
out := new(StartSessionResponse)
|
out := new(StartSessionResponse)
|
||||||
err := c.cc.Invoke(ctx, SessionCoordinator_StartSession_FullMethodName, in, out, opts...)
|
err := c.cc.Invoke(ctx, SessionCoordinator_StartSession_FullMethodName, in, out, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sessionCoordinatorClient) ReportCompletion(ctx context.Context, in *ReportCompletionRequest, opts ...grpc.CallOption) (*ReportCompletionResponse, error) {
|
func (c *sessionCoordinatorClient) ReportCompletion(ctx context.Context, in *ReportCompletionRequest, opts ...grpc.CallOption) (*ReportCompletionResponse, error) {
|
||||||
out := new(ReportCompletionResponse)
|
out := new(ReportCompletionResponse)
|
||||||
err := c.cc.Invoke(ctx, SessionCoordinator_ReportCompletion_FullMethodName, in, out, opts...)
|
err := c.cc.Invoke(ctx, SessionCoordinator_ReportCompletion_FullMethodName, in, out, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sessionCoordinatorClient) CloseSession(ctx context.Context, in *CloseSessionRequest, opts ...grpc.CallOption) (*CloseSessionResponse, error) {
|
func (c *sessionCoordinatorClient) CloseSession(ctx context.Context, in *CloseSessionRequest, opts ...grpc.CallOption) (*CloseSessionResponse, error) {
|
||||||
out := new(CloseSessionResponse)
|
out := new(CloseSessionResponse)
|
||||||
err := c.cc.Invoke(ctx, SessionCoordinator_CloseSession_FullMethodName, in, out, opts...)
|
err := c.cc.Invoke(ctx, SessionCoordinator_CloseSession_FullMethodName, in, out, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionCoordinatorServer is the server API for SessionCoordinator service.
|
// SessionCoordinatorServer is the server API for SessionCoordinator service.
|
||||||
// All implementations must embed UnimplementedSessionCoordinatorServer
|
// All implementations must embed UnimplementedSessionCoordinatorServer
|
||||||
// for forward compatibility
|
// for forward compatibility
|
||||||
type SessionCoordinatorServer interface {
|
type SessionCoordinatorServer interface {
|
||||||
// Session management
|
// Session management
|
||||||
CreateSession(context.Context, *CreateSessionRequest) (*CreateSessionResponse, error)
|
CreateSession(context.Context, *CreateSessionRequest) (*CreateSessionResponse, error)
|
||||||
JoinSession(context.Context, *JoinSessionRequest) (*JoinSessionResponse, error)
|
JoinSession(context.Context, *JoinSessionRequest) (*JoinSessionResponse, error)
|
||||||
GetSessionStatus(context.Context, *GetSessionStatusRequest) (*GetSessionStatusResponse, error)
|
GetSessionStatus(context.Context, *GetSessionStatusRequest) (*GetSessionStatusResponse, error)
|
||||||
MarkPartyReady(context.Context, *MarkPartyReadyRequest) (*MarkPartyReadyResponse, error)
|
MarkPartyReady(context.Context, *MarkPartyReadyRequest) (*MarkPartyReadyResponse, error)
|
||||||
StartSession(context.Context, *StartSessionRequest) (*StartSessionResponse, error)
|
StartSession(context.Context, *StartSessionRequest) (*StartSessionResponse, error)
|
||||||
ReportCompletion(context.Context, *ReportCompletionRequest) (*ReportCompletionResponse, error)
|
ReportCompletion(context.Context, *ReportCompletionRequest) (*ReportCompletionResponse, error)
|
||||||
CloseSession(context.Context, *CloseSessionRequest) (*CloseSessionResponse, error)
|
CloseSession(context.Context, *CloseSessionRequest) (*CloseSessionResponse, error)
|
||||||
mustEmbedUnimplementedSessionCoordinatorServer()
|
mustEmbedUnimplementedSessionCoordinatorServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnimplementedSessionCoordinatorServer must be embedded to have forward compatible implementations.
|
// UnimplementedSessionCoordinatorServer must be embedded to have forward compatible implementations.
|
||||||
type UnimplementedSessionCoordinatorServer struct {
|
type UnimplementedSessionCoordinatorServer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (UnimplementedSessionCoordinatorServer) CreateSession(context.Context, *CreateSessionRequest) (*CreateSessionResponse, error) {
|
func (UnimplementedSessionCoordinatorServer) CreateSession(context.Context, *CreateSessionRequest) (*CreateSessionResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateSession not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateSession not implemented")
|
||||||
}
|
}
|
||||||
func (UnimplementedSessionCoordinatorServer) JoinSession(context.Context, *JoinSessionRequest) (*JoinSessionResponse, error) {
|
func (UnimplementedSessionCoordinatorServer) JoinSession(context.Context, *JoinSessionRequest) (*JoinSessionResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method JoinSession not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method JoinSession not implemented")
|
||||||
}
|
}
|
||||||
func (UnimplementedSessionCoordinatorServer) GetSessionStatus(context.Context, *GetSessionStatusRequest) (*GetSessionStatusResponse, error) {
|
func (UnimplementedSessionCoordinatorServer) GetSessionStatus(context.Context, *GetSessionStatusRequest) (*GetSessionStatusResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetSessionStatus not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetSessionStatus not implemented")
|
||||||
}
|
}
|
||||||
func (UnimplementedSessionCoordinatorServer) MarkPartyReady(context.Context, *MarkPartyReadyRequest) (*MarkPartyReadyResponse, error) {
|
func (UnimplementedSessionCoordinatorServer) MarkPartyReady(context.Context, *MarkPartyReadyRequest) (*MarkPartyReadyResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method MarkPartyReady not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method MarkPartyReady not implemented")
|
||||||
}
|
}
|
||||||
func (UnimplementedSessionCoordinatorServer) StartSession(context.Context, *StartSessionRequest) (*StartSessionResponse, error) {
|
func (UnimplementedSessionCoordinatorServer) StartSession(context.Context, *StartSessionRequest) (*StartSessionResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method StartSession not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method StartSession not implemented")
|
||||||
}
|
}
|
||||||
func (UnimplementedSessionCoordinatorServer) ReportCompletion(context.Context, *ReportCompletionRequest) (*ReportCompletionResponse, error) {
|
func (UnimplementedSessionCoordinatorServer) ReportCompletion(context.Context, *ReportCompletionRequest) (*ReportCompletionResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method ReportCompletion not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method ReportCompletion not implemented")
|
||||||
}
|
}
|
||||||
func (UnimplementedSessionCoordinatorServer) CloseSession(context.Context, *CloseSessionRequest) (*CloseSessionResponse, error) {
|
func (UnimplementedSessionCoordinatorServer) CloseSession(context.Context, *CloseSessionRequest) (*CloseSessionResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method CloseSession not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CloseSession not implemented")
|
||||||
}
|
}
|
||||||
func (UnimplementedSessionCoordinatorServer) mustEmbedUnimplementedSessionCoordinatorServer() {}
|
func (UnimplementedSessionCoordinatorServer) mustEmbedUnimplementedSessionCoordinatorServer() {}
|
||||||
|
|
||||||
// UnsafeSessionCoordinatorServer may be embedded to opt out of forward compatibility for this service.
|
// UnsafeSessionCoordinatorServer may be embedded to opt out of forward compatibility for this service.
|
||||||
// Use of this interface is not recommended, as added methods to SessionCoordinatorServer will
|
// Use of this interface is not recommended, as added methods to SessionCoordinatorServer will
|
||||||
// result in compilation errors.
|
// result in compilation errors.
|
||||||
type UnsafeSessionCoordinatorServer interface {
|
type UnsafeSessionCoordinatorServer interface {
|
||||||
mustEmbedUnimplementedSessionCoordinatorServer()
|
mustEmbedUnimplementedSessionCoordinatorServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterSessionCoordinatorServer(s grpc.ServiceRegistrar, srv SessionCoordinatorServer) {
|
func RegisterSessionCoordinatorServer(s grpc.ServiceRegistrar, srv SessionCoordinatorServer) {
|
||||||
s.RegisterService(&SessionCoordinator_ServiceDesc, srv)
|
s.RegisterService(&SessionCoordinator_ServiceDesc, srv)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _SessionCoordinator_CreateSession_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
func _SessionCoordinator_CreateSession_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
in := new(CreateSessionRequest)
|
in := new(CreateSessionRequest)
|
||||||
if err := dec(in); err != nil {
|
if err := dec(in); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if interceptor == nil {
|
if interceptor == nil {
|
||||||
return srv.(SessionCoordinatorServer).CreateSession(ctx, in)
|
return srv.(SessionCoordinatorServer).CreateSession(ctx, in)
|
||||||
}
|
}
|
||||||
info := &grpc.UnaryServerInfo{
|
info := &grpc.UnaryServerInfo{
|
||||||
Server: srv,
|
Server: srv,
|
||||||
FullMethod: SessionCoordinator_CreateSession_FullMethodName,
|
FullMethod: SessionCoordinator_CreateSession_FullMethodName,
|
||||||
}
|
}
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
return srv.(SessionCoordinatorServer).CreateSession(ctx, req.(*CreateSessionRequest))
|
return srv.(SessionCoordinatorServer).CreateSession(ctx, req.(*CreateSessionRequest))
|
||||||
}
|
}
|
||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _SessionCoordinator_JoinSession_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
func _SessionCoordinator_JoinSession_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
in := new(JoinSessionRequest)
|
in := new(JoinSessionRequest)
|
||||||
if err := dec(in); err != nil {
|
if err := dec(in); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if interceptor == nil {
|
if interceptor == nil {
|
||||||
return srv.(SessionCoordinatorServer).JoinSession(ctx, in)
|
return srv.(SessionCoordinatorServer).JoinSession(ctx, in)
|
||||||
}
|
}
|
||||||
info := &grpc.UnaryServerInfo{
|
info := &grpc.UnaryServerInfo{
|
||||||
Server: srv,
|
Server: srv,
|
||||||
FullMethod: SessionCoordinator_JoinSession_FullMethodName,
|
FullMethod: SessionCoordinator_JoinSession_FullMethodName,
|
||||||
}
|
}
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
return srv.(SessionCoordinatorServer).JoinSession(ctx, req.(*JoinSessionRequest))
|
return srv.(SessionCoordinatorServer).JoinSession(ctx, req.(*JoinSessionRequest))
|
||||||
}
|
}
|
||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _SessionCoordinator_GetSessionStatus_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
func _SessionCoordinator_GetSessionStatus_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
in := new(GetSessionStatusRequest)
|
in := new(GetSessionStatusRequest)
|
||||||
if err := dec(in); err != nil {
|
if err := dec(in); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if interceptor == nil {
|
if interceptor == nil {
|
||||||
return srv.(SessionCoordinatorServer).GetSessionStatus(ctx, in)
|
return srv.(SessionCoordinatorServer).GetSessionStatus(ctx, in)
|
||||||
}
|
}
|
||||||
info := &grpc.UnaryServerInfo{
|
info := &grpc.UnaryServerInfo{
|
||||||
Server: srv,
|
Server: srv,
|
||||||
FullMethod: SessionCoordinator_GetSessionStatus_FullMethodName,
|
FullMethod: SessionCoordinator_GetSessionStatus_FullMethodName,
|
||||||
}
|
}
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
return srv.(SessionCoordinatorServer).GetSessionStatus(ctx, req.(*GetSessionStatusRequest))
|
return srv.(SessionCoordinatorServer).GetSessionStatus(ctx, req.(*GetSessionStatusRequest))
|
||||||
}
|
}
|
||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _SessionCoordinator_MarkPartyReady_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
func _SessionCoordinator_MarkPartyReady_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
in := new(MarkPartyReadyRequest)
|
in := new(MarkPartyReadyRequest)
|
||||||
if err := dec(in); err != nil {
|
if err := dec(in); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if interceptor == nil {
|
if interceptor == nil {
|
||||||
return srv.(SessionCoordinatorServer).MarkPartyReady(ctx, in)
|
return srv.(SessionCoordinatorServer).MarkPartyReady(ctx, in)
|
||||||
}
|
}
|
||||||
info := &grpc.UnaryServerInfo{
|
info := &grpc.UnaryServerInfo{
|
||||||
Server: srv,
|
Server: srv,
|
||||||
FullMethod: SessionCoordinator_MarkPartyReady_FullMethodName,
|
FullMethod: SessionCoordinator_MarkPartyReady_FullMethodName,
|
||||||
}
|
}
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
return srv.(SessionCoordinatorServer).MarkPartyReady(ctx, req.(*MarkPartyReadyRequest))
|
return srv.(SessionCoordinatorServer).MarkPartyReady(ctx, req.(*MarkPartyReadyRequest))
|
||||||
}
|
}
|
||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _SessionCoordinator_StartSession_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
func _SessionCoordinator_StartSession_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
in := new(StartSessionRequest)
|
in := new(StartSessionRequest)
|
||||||
if err := dec(in); err != nil {
|
if err := dec(in); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if interceptor == nil {
|
if interceptor == nil {
|
||||||
return srv.(SessionCoordinatorServer).StartSession(ctx, in)
|
return srv.(SessionCoordinatorServer).StartSession(ctx, in)
|
||||||
}
|
}
|
||||||
info := &grpc.UnaryServerInfo{
|
info := &grpc.UnaryServerInfo{
|
||||||
Server: srv,
|
Server: srv,
|
||||||
FullMethod: SessionCoordinator_StartSession_FullMethodName,
|
FullMethod: SessionCoordinator_StartSession_FullMethodName,
|
||||||
}
|
}
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
return srv.(SessionCoordinatorServer).StartSession(ctx, req.(*StartSessionRequest))
|
return srv.(SessionCoordinatorServer).StartSession(ctx, req.(*StartSessionRequest))
|
||||||
}
|
}
|
||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _SessionCoordinator_ReportCompletion_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
func _SessionCoordinator_ReportCompletion_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
in := new(ReportCompletionRequest)
|
in := new(ReportCompletionRequest)
|
||||||
if err := dec(in); err != nil {
|
if err := dec(in); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if interceptor == nil {
|
if interceptor == nil {
|
||||||
return srv.(SessionCoordinatorServer).ReportCompletion(ctx, in)
|
return srv.(SessionCoordinatorServer).ReportCompletion(ctx, in)
|
||||||
}
|
}
|
||||||
info := &grpc.UnaryServerInfo{
|
info := &grpc.UnaryServerInfo{
|
||||||
Server: srv,
|
Server: srv,
|
||||||
FullMethod: SessionCoordinator_ReportCompletion_FullMethodName,
|
FullMethod: SessionCoordinator_ReportCompletion_FullMethodName,
|
||||||
}
|
}
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
return srv.(SessionCoordinatorServer).ReportCompletion(ctx, req.(*ReportCompletionRequest))
|
return srv.(SessionCoordinatorServer).ReportCompletion(ctx, req.(*ReportCompletionRequest))
|
||||||
}
|
}
|
||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _SessionCoordinator_CloseSession_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
func _SessionCoordinator_CloseSession_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
in := new(CloseSessionRequest)
|
in := new(CloseSessionRequest)
|
||||||
if err := dec(in); err != nil {
|
if err := dec(in); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if interceptor == nil {
|
if interceptor == nil {
|
||||||
return srv.(SessionCoordinatorServer).CloseSession(ctx, in)
|
return srv.(SessionCoordinatorServer).CloseSession(ctx, in)
|
||||||
}
|
}
|
||||||
info := &grpc.UnaryServerInfo{
|
info := &grpc.UnaryServerInfo{
|
||||||
Server: srv,
|
Server: srv,
|
||||||
FullMethod: SessionCoordinator_CloseSession_FullMethodName,
|
FullMethod: SessionCoordinator_CloseSession_FullMethodName,
|
||||||
}
|
}
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
return srv.(SessionCoordinatorServer).CloseSession(ctx, req.(*CloseSessionRequest))
|
return srv.(SessionCoordinatorServer).CloseSession(ctx, req.(*CloseSessionRequest))
|
||||||
}
|
}
|
||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionCoordinator_ServiceDesc is the grpc.ServiceDesc for SessionCoordinator service.
|
// SessionCoordinator_ServiceDesc is the grpc.ServiceDesc for SessionCoordinator service.
|
||||||
// It's only intended for direct use with grpc.RegisterService,
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
// and not to be introspected or modified (even as a copy)
|
// and not to be introspected or modified (even as a copy)
|
||||||
var SessionCoordinator_ServiceDesc = grpc.ServiceDesc{
|
var SessionCoordinator_ServiceDesc = grpc.ServiceDesc{
|
||||||
ServiceName: "mpc.coordinator.v1.SessionCoordinator",
|
ServiceName: "mpc.coordinator.v1.SessionCoordinator",
|
||||||
HandlerType: (*SessionCoordinatorServer)(nil),
|
HandlerType: (*SessionCoordinatorServer)(nil),
|
||||||
Methods: []grpc.MethodDesc{
|
Methods: []grpc.MethodDesc{
|
||||||
{
|
{
|
||||||
MethodName: "CreateSession",
|
MethodName: "CreateSession",
|
||||||
Handler: _SessionCoordinator_CreateSession_Handler,
|
Handler: _SessionCoordinator_CreateSession_Handler,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
MethodName: "JoinSession",
|
MethodName: "JoinSession",
|
||||||
Handler: _SessionCoordinator_JoinSession_Handler,
|
Handler: _SessionCoordinator_JoinSession_Handler,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
MethodName: "GetSessionStatus",
|
MethodName: "GetSessionStatus",
|
||||||
Handler: _SessionCoordinator_GetSessionStatus_Handler,
|
Handler: _SessionCoordinator_GetSessionStatus_Handler,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
MethodName: "MarkPartyReady",
|
MethodName: "MarkPartyReady",
|
||||||
Handler: _SessionCoordinator_MarkPartyReady_Handler,
|
Handler: _SessionCoordinator_MarkPartyReady_Handler,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
MethodName: "StartSession",
|
MethodName: "StartSession",
|
||||||
Handler: _SessionCoordinator_StartSession_Handler,
|
Handler: _SessionCoordinator_StartSession_Handler,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
MethodName: "ReportCompletion",
|
MethodName: "ReportCompletion",
|
||||||
Handler: _SessionCoordinator_ReportCompletion_Handler,
|
Handler: _SessionCoordinator_ReportCompletion_Handler,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
MethodName: "CloseSession",
|
MethodName: "CloseSession",
|
||||||
Handler: _SessionCoordinator_CloseSession_Handler,
|
Handler: _SessionCoordinator_CloseSession_Handler,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Streams: []grpc.StreamDesc{},
|
Streams: []grpc.StreamDesc{},
|
||||||
Metadata: "api/proto/session_coordinator.proto",
|
Metadata: "api/proto/session_coordinator.proto",
|
||||||
}
|
}
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,217 +1,217 @@
|
||||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// - protoc-gen-go-grpc v1.3.0
|
// - protoc-gen-go-grpc v1.3.0
|
||||||
// - protoc v3.12.4
|
// - protoc v3.12.4
|
||||||
// source: api/proto/message_router.proto
|
// source: api/proto/message_router.proto
|
||||||
|
|
||||||
package router
|
package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
context "context"
|
context "context"
|
||||||
grpc "google.golang.org/grpc"
|
grpc "google.golang.org/grpc"
|
||||||
codes "google.golang.org/grpc/codes"
|
codes "google.golang.org/grpc/codes"
|
||||||
status "google.golang.org/grpc/status"
|
status "google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This is a compile-time assertion to ensure that this generated file
|
// This is a compile-time assertion to ensure that this generated file
|
||||||
// is compatible with the grpc package it is being compiled against.
|
// is compatible with the grpc package it is being compiled against.
|
||||||
// Requires gRPC-Go v1.32.0 or later.
|
// Requires gRPC-Go v1.32.0 or later.
|
||||||
const _ = grpc.SupportPackageIsVersion7
|
const _ = grpc.SupportPackageIsVersion7
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MessageRouter_RouteMessage_FullMethodName = "/mpc.router.v1.MessageRouter/RouteMessage"
|
MessageRouter_RouteMessage_FullMethodName = "/mpc.router.v1.MessageRouter/RouteMessage"
|
||||||
MessageRouter_SubscribeMessages_FullMethodName = "/mpc.router.v1.MessageRouter/SubscribeMessages"
|
MessageRouter_SubscribeMessages_FullMethodName = "/mpc.router.v1.MessageRouter/SubscribeMessages"
|
||||||
MessageRouter_GetPendingMessages_FullMethodName = "/mpc.router.v1.MessageRouter/GetPendingMessages"
|
MessageRouter_GetPendingMessages_FullMethodName = "/mpc.router.v1.MessageRouter/GetPendingMessages"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MessageRouterClient is the client API for MessageRouter service.
|
// MessageRouterClient is the client API for MessageRouter service.
|
||||||
//
|
//
|
||||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||||
type MessageRouterClient interface {
|
type MessageRouterClient interface {
|
||||||
// RouteMessage routes a message from one party to others
|
// RouteMessage routes a message from one party to others
|
||||||
RouteMessage(ctx context.Context, in *RouteMessageRequest, opts ...grpc.CallOption) (*RouteMessageResponse, error)
|
RouteMessage(ctx context.Context, in *RouteMessageRequest, opts ...grpc.CallOption) (*RouteMessageResponse, error)
|
||||||
// SubscribeMessages subscribes to messages for a party (streaming)
|
// SubscribeMessages subscribes to messages for a party (streaming)
|
||||||
SubscribeMessages(ctx context.Context, in *SubscribeMessagesRequest, opts ...grpc.CallOption) (MessageRouter_SubscribeMessagesClient, error)
|
SubscribeMessages(ctx context.Context, in *SubscribeMessagesRequest, opts ...grpc.CallOption) (MessageRouter_SubscribeMessagesClient, error)
|
||||||
// GetPendingMessages retrieves pending messages (polling alternative)
|
// GetPendingMessages retrieves pending messages (polling alternative)
|
||||||
GetPendingMessages(ctx context.Context, in *GetPendingMessagesRequest, opts ...grpc.CallOption) (*GetPendingMessagesResponse, error)
|
GetPendingMessages(ctx context.Context, in *GetPendingMessagesRequest, opts ...grpc.CallOption) (*GetPendingMessagesResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type messageRouterClient struct {
|
type messageRouterClient struct {
|
||||||
cc grpc.ClientConnInterface
|
cc grpc.ClientConnInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMessageRouterClient(cc grpc.ClientConnInterface) MessageRouterClient {
|
func NewMessageRouterClient(cc grpc.ClientConnInterface) MessageRouterClient {
|
||||||
return &messageRouterClient{cc}
|
return &messageRouterClient{cc}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *messageRouterClient) RouteMessage(ctx context.Context, in *RouteMessageRequest, opts ...grpc.CallOption) (*RouteMessageResponse, error) {
|
func (c *messageRouterClient) RouteMessage(ctx context.Context, in *RouteMessageRequest, opts ...grpc.CallOption) (*RouteMessageResponse, error) {
|
||||||
out := new(RouteMessageResponse)
|
out := new(RouteMessageResponse)
|
||||||
err := c.cc.Invoke(ctx, MessageRouter_RouteMessage_FullMethodName, in, out, opts...)
|
err := c.cc.Invoke(ctx, MessageRouter_RouteMessage_FullMethodName, in, out, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *messageRouterClient) SubscribeMessages(ctx context.Context, in *SubscribeMessagesRequest, opts ...grpc.CallOption) (MessageRouter_SubscribeMessagesClient, error) {
|
func (c *messageRouterClient) SubscribeMessages(ctx context.Context, in *SubscribeMessagesRequest, opts ...grpc.CallOption) (MessageRouter_SubscribeMessagesClient, error) {
|
||||||
stream, err := c.cc.NewStream(ctx, &MessageRouter_ServiceDesc.Streams[0], MessageRouter_SubscribeMessages_FullMethodName, opts...)
|
stream, err := c.cc.NewStream(ctx, &MessageRouter_ServiceDesc.Streams[0], MessageRouter_SubscribeMessages_FullMethodName, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
x := &messageRouterSubscribeMessagesClient{stream}
|
x := &messageRouterSubscribeMessagesClient{stream}
|
||||||
if err := x.ClientStream.SendMsg(in); err != nil {
|
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := x.ClientStream.CloseSend(); err != nil {
|
if err := x.ClientStream.CloseSend(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return x, nil
|
return x, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type MessageRouter_SubscribeMessagesClient interface {
|
type MessageRouter_SubscribeMessagesClient interface {
|
||||||
Recv() (*MPCMessage, error)
|
Recv() (*MPCMessage, error)
|
||||||
grpc.ClientStream
|
grpc.ClientStream
|
||||||
}
|
}
|
||||||
|
|
||||||
type messageRouterSubscribeMessagesClient struct {
|
type messageRouterSubscribeMessagesClient struct {
|
||||||
grpc.ClientStream
|
grpc.ClientStream
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *messageRouterSubscribeMessagesClient) Recv() (*MPCMessage, error) {
|
func (x *messageRouterSubscribeMessagesClient) Recv() (*MPCMessage, error) {
|
||||||
m := new(MPCMessage)
|
m := new(MPCMessage)
|
||||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *messageRouterClient) GetPendingMessages(ctx context.Context, in *GetPendingMessagesRequest, opts ...grpc.CallOption) (*GetPendingMessagesResponse, error) {
|
func (c *messageRouterClient) GetPendingMessages(ctx context.Context, in *GetPendingMessagesRequest, opts ...grpc.CallOption) (*GetPendingMessagesResponse, error) {
|
||||||
out := new(GetPendingMessagesResponse)
|
out := new(GetPendingMessagesResponse)
|
||||||
err := c.cc.Invoke(ctx, MessageRouter_GetPendingMessages_FullMethodName, in, out, opts...)
|
err := c.cc.Invoke(ctx, MessageRouter_GetPendingMessages_FullMethodName, in, out, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MessageRouterServer is the server API for MessageRouter service.
|
// MessageRouterServer is the server API for MessageRouter service.
|
||||||
// All implementations must embed UnimplementedMessageRouterServer
|
// All implementations must embed UnimplementedMessageRouterServer
|
||||||
// for forward compatibility
|
// for forward compatibility
|
||||||
type MessageRouterServer interface {
|
type MessageRouterServer interface {
|
||||||
// RouteMessage routes a message from one party to others
|
// RouteMessage routes a message from one party to others
|
||||||
RouteMessage(context.Context, *RouteMessageRequest) (*RouteMessageResponse, error)
|
RouteMessage(context.Context, *RouteMessageRequest) (*RouteMessageResponse, error)
|
||||||
// SubscribeMessages subscribes to messages for a party (streaming)
|
// SubscribeMessages subscribes to messages for a party (streaming)
|
||||||
SubscribeMessages(*SubscribeMessagesRequest, MessageRouter_SubscribeMessagesServer) error
|
SubscribeMessages(*SubscribeMessagesRequest, MessageRouter_SubscribeMessagesServer) error
|
||||||
// GetPendingMessages retrieves pending messages (polling alternative)
|
// GetPendingMessages retrieves pending messages (polling alternative)
|
||||||
GetPendingMessages(context.Context, *GetPendingMessagesRequest) (*GetPendingMessagesResponse, error)
|
GetPendingMessages(context.Context, *GetPendingMessagesRequest) (*GetPendingMessagesResponse, error)
|
||||||
mustEmbedUnimplementedMessageRouterServer()
|
mustEmbedUnimplementedMessageRouterServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnimplementedMessageRouterServer must be embedded to have forward compatible implementations.
|
// UnimplementedMessageRouterServer must be embedded to have forward compatible implementations.
|
||||||
type UnimplementedMessageRouterServer struct {
|
type UnimplementedMessageRouterServer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (UnimplementedMessageRouterServer) RouteMessage(context.Context, *RouteMessageRequest) (*RouteMessageResponse, error) {
|
func (UnimplementedMessageRouterServer) RouteMessage(context.Context, *RouteMessageRequest) (*RouteMessageResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method RouteMessage not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method RouteMessage not implemented")
|
||||||
}
|
}
|
||||||
func (UnimplementedMessageRouterServer) SubscribeMessages(*SubscribeMessagesRequest, MessageRouter_SubscribeMessagesServer) error {
|
func (UnimplementedMessageRouterServer) SubscribeMessages(*SubscribeMessagesRequest, MessageRouter_SubscribeMessagesServer) error {
|
||||||
return status.Errorf(codes.Unimplemented, "method SubscribeMessages not implemented")
|
return status.Errorf(codes.Unimplemented, "method SubscribeMessages not implemented")
|
||||||
}
|
}
|
||||||
func (UnimplementedMessageRouterServer) GetPendingMessages(context.Context, *GetPendingMessagesRequest) (*GetPendingMessagesResponse, error) {
|
func (UnimplementedMessageRouterServer) GetPendingMessages(context.Context, *GetPendingMessagesRequest) (*GetPendingMessagesResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetPendingMessages not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetPendingMessages not implemented")
|
||||||
}
|
}
|
||||||
func (UnimplementedMessageRouterServer) mustEmbedUnimplementedMessageRouterServer() {}
|
func (UnimplementedMessageRouterServer) mustEmbedUnimplementedMessageRouterServer() {}
|
||||||
|
|
||||||
// UnsafeMessageRouterServer may be embedded to opt out of forward compatibility for this service.
|
// UnsafeMessageRouterServer may be embedded to opt out of forward compatibility for this service.
|
||||||
// Use of this interface is not recommended, as added methods to MessageRouterServer will
|
// Use of this interface is not recommended, as added methods to MessageRouterServer will
|
||||||
// result in compilation errors.
|
// result in compilation errors.
|
||||||
type UnsafeMessageRouterServer interface {
|
type UnsafeMessageRouterServer interface {
|
||||||
mustEmbedUnimplementedMessageRouterServer()
|
mustEmbedUnimplementedMessageRouterServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterMessageRouterServer(s grpc.ServiceRegistrar, srv MessageRouterServer) {
|
func RegisterMessageRouterServer(s grpc.ServiceRegistrar, srv MessageRouterServer) {
|
||||||
s.RegisterService(&MessageRouter_ServiceDesc, srv)
|
s.RegisterService(&MessageRouter_ServiceDesc, srv)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _MessageRouter_RouteMessage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
func _MessageRouter_RouteMessage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
in := new(RouteMessageRequest)
|
in := new(RouteMessageRequest)
|
||||||
if err := dec(in); err != nil {
|
if err := dec(in); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if interceptor == nil {
|
if interceptor == nil {
|
||||||
return srv.(MessageRouterServer).RouteMessage(ctx, in)
|
return srv.(MessageRouterServer).RouteMessage(ctx, in)
|
||||||
}
|
}
|
||||||
info := &grpc.UnaryServerInfo{
|
info := &grpc.UnaryServerInfo{
|
||||||
Server: srv,
|
Server: srv,
|
||||||
FullMethod: MessageRouter_RouteMessage_FullMethodName,
|
FullMethod: MessageRouter_RouteMessage_FullMethodName,
|
||||||
}
|
}
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
return srv.(MessageRouterServer).RouteMessage(ctx, req.(*RouteMessageRequest))
|
return srv.(MessageRouterServer).RouteMessage(ctx, req.(*RouteMessageRequest))
|
||||||
}
|
}
|
||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _MessageRouter_SubscribeMessages_Handler(srv interface{}, stream grpc.ServerStream) error {
|
func _MessageRouter_SubscribeMessages_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||||
m := new(SubscribeMessagesRequest)
|
m := new(SubscribeMessagesRequest)
|
||||||
if err := stream.RecvMsg(m); err != nil {
|
if err := stream.RecvMsg(m); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return srv.(MessageRouterServer).SubscribeMessages(m, &messageRouterSubscribeMessagesServer{stream})
|
return srv.(MessageRouterServer).SubscribeMessages(m, &messageRouterSubscribeMessagesServer{stream})
|
||||||
}
|
}
|
||||||
|
|
||||||
type MessageRouter_SubscribeMessagesServer interface {
|
type MessageRouter_SubscribeMessagesServer interface {
|
||||||
Send(*MPCMessage) error
|
Send(*MPCMessage) error
|
||||||
grpc.ServerStream
|
grpc.ServerStream
|
||||||
}
|
}
|
||||||
|
|
||||||
type messageRouterSubscribeMessagesServer struct {
|
type messageRouterSubscribeMessagesServer struct {
|
||||||
grpc.ServerStream
|
grpc.ServerStream
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *messageRouterSubscribeMessagesServer) Send(m *MPCMessage) error {
|
func (x *messageRouterSubscribeMessagesServer) Send(m *MPCMessage) error {
|
||||||
return x.ServerStream.SendMsg(m)
|
return x.ServerStream.SendMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _MessageRouter_GetPendingMessages_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
func _MessageRouter_GetPendingMessages_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
in := new(GetPendingMessagesRequest)
|
in := new(GetPendingMessagesRequest)
|
||||||
if err := dec(in); err != nil {
|
if err := dec(in); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if interceptor == nil {
|
if interceptor == nil {
|
||||||
return srv.(MessageRouterServer).GetPendingMessages(ctx, in)
|
return srv.(MessageRouterServer).GetPendingMessages(ctx, in)
|
||||||
}
|
}
|
||||||
info := &grpc.UnaryServerInfo{
|
info := &grpc.UnaryServerInfo{
|
||||||
Server: srv,
|
Server: srv,
|
||||||
FullMethod: MessageRouter_GetPendingMessages_FullMethodName,
|
FullMethod: MessageRouter_GetPendingMessages_FullMethodName,
|
||||||
}
|
}
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
return srv.(MessageRouterServer).GetPendingMessages(ctx, req.(*GetPendingMessagesRequest))
|
return srv.(MessageRouterServer).GetPendingMessages(ctx, req.(*GetPendingMessagesRequest))
|
||||||
}
|
}
|
||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MessageRouter_ServiceDesc is the grpc.ServiceDesc for MessageRouter service.
|
// MessageRouter_ServiceDesc is the grpc.ServiceDesc for MessageRouter service.
|
||||||
// It's only intended for direct use with grpc.RegisterService,
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
// and not to be introspected or modified (even as a copy)
|
// and not to be introspected or modified (even as a copy)
|
||||||
var MessageRouter_ServiceDesc = grpc.ServiceDesc{
|
var MessageRouter_ServiceDesc = grpc.ServiceDesc{
|
||||||
ServiceName: "mpc.router.v1.MessageRouter",
|
ServiceName: "mpc.router.v1.MessageRouter",
|
||||||
HandlerType: (*MessageRouterServer)(nil),
|
HandlerType: (*MessageRouterServer)(nil),
|
||||||
Methods: []grpc.MethodDesc{
|
Methods: []grpc.MethodDesc{
|
||||||
{
|
{
|
||||||
MethodName: "RouteMessage",
|
MethodName: "RouteMessage",
|
||||||
Handler: _MessageRouter_RouteMessage_Handler,
|
Handler: _MessageRouter_RouteMessage_Handler,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
MethodName: "GetPendingMessages",
|
MethodName: "GetPendingMessages",
|
||||||
Handler: _MessageRouter_GetPendingMessages_Handler,
|
Handler: _MessageRouter_GetPendingMessages_Handler,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Streams: []grpc.StreamDesc{
|
Streams: []grpc.StreamDesc{
|
||||||
{
|
{
|
||||||
StreamName: "SubscribeMessages",
|
StreamName: "SubscribeMessages",
|
||||||
Handler: _MessageRouter_SubscribeMessages_Handler,
|
Handler: _MessageRouter_SubscribeMessages_Handler,
|
||||||
ServerStreams: true,
|
ServerStreams: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Metadata: "api/proto/message_router.proto",
|
Metadata: "api/proto/message_router.proto",
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,63 +1,103 @@
|
||||||
syntax = "proto3";
|
syntax = "proto3";
|
||||||
|
|
||||||
package mpc.router.v1;
|
package mpc.router.v1;
|
||||||
|
|
||||||
option go_package = "github.com/rwadurian/mpc-system/api/grpc/router/v1;router";
|
option go_package = "github.com/rwadurian/mpc-system/api/grpc/router/v1;router";
|
||||||
|
|
||||||
// MessageRouter service handles MPC message routing
|
// MessageRouter service handles MPC message routing
|
||||||
service MessageRouter {
|
service MessageRouter {
|
||||||
// RouteMessage routes a message from one party to others
|
// RouteMessage routes a message from one party to others
|
||||||
rpc RouteMessage(RouteMessageRequest) returns (RouteMessageResponse);
|
rpc RouteMessage(RouteMessageRequest) returns (RouteMessageResponse);
|
||||||
|
|
||||||
// SubscribeMessages subscribes to messages for a party (streaming)
|
// SubscribeMessages subscribes to messages for a party (streaming)
|
||||||
rpc SubscribeMessages(SubscribeMessagesRequest) returns (stream MPCMessage);
|
rpc SubscribeMessages(SubscribeMessagesRequest) returns (stream MPCMessage);
|
||||||
|
|
||||||
// GetPendingMessages retrieves pending messages (polling alternative)
|
// GetPendingMessages retrieves pending messages (polling alternative)
|
||||||
rpc GetPendingMessages(GetPendingMessagesRequest) returns (GetPendingMessagesResponse);
|
rpc GetPendingMessages(GetPendingMessagesRequest) returns (GetPendingMessagesResponse);
|
||||||
}
|
|
||||||
|
// RegisterParty registers a party with the message router (party actively connects)
|
||||||
// RouteMessageRequest routes an MPC message
|
rpc RegisterParty(RegisterPartyRequest) returns (RegisterPartyResponse);
|
||||||
message RouteMessageRequest {
|
|
||||||
string session_id = 1;
|
// SubscribeSessionEvents subscribes to session lifecycle events (session start, etc.)
|
||||||
string from_party = 2;
|
rpc SubscribeSessionEvents(SubscribeSessionEventsRequest) returns (stream SessionEvent);
|
||||||
repeated string to_parties = 3; // Empty for broadcast
|
}
|
||||||
int32 round_number = 4;
|
|
||||||
string message_type = 5;
|
// RouteMessageRequest routes an MPC message
|
||||||
bytes payload = 6; // Encrypted MPC message
|
message RouteMessageRequest {
|
||||||
}
|
string session_id = 1;
|
||||||
|
string from_party = 2;
|
||||||
// RouteMessageResponse confirms message routing
|
repeated string to_parties = 3; // Empty for broadcast
|
||||||
message RouteMessageResponse {
|
int32 round_number = 4;
|
||||||
bool success = 1;
|
string message_type = 5;
|
||||||
string message_id = 2;
|
bytes payload = 6; // Encrypted MPC message
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubscribeMessagesRequest subscribes to messages for a party
|
// RouteMessageResponse confirms message routing
|
||||||
message SubscribeMessagesRequest {
|
message RouteMessageResponse {
|
||||||
string session_id = 1;
|
bool success = 1;
|
||||||
string party_id = 2;
|
string message_id = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// MPCMessage represents an MPC protocol message
|
// SubscribeMessagesRequest subscribes to messages for a party
|
||||||
message MPCMessage {
|
message SubscribeMessagesRequest {
|
||||||
string message_id = 1;
|
string session_id = 1;
|
||||||
string session_id = 2;
|
string party_id = 2;
|
||||||
string from_party = 3;
|
}
|
||||||
bool is_broadcast = 4;
|
|
||||||
int32 round_number = 5;
|
// MPCMessage represents an MPC protocol message
|
||||||
string message_type = 6;
|
message MPCMessage {
|
||||||
bytes payload = 7;
|
string message_id = 1;
|
||||||
int64 created_at = 8; // Unix timestamp milliseconds
|
string session_id = 2;
|
||||||
}
|
string from_party = 3;
|
||||||
|
bool is_broadcast = 4;
|
||||||
// GetPendingMessagesRequest retrieves pending messages
|
int32 round_number = 5;
|
||||||
message GetPendingMessagesRequest {
|
string message_type = 6;
|
||||||
string session_id = 1;
|
bytes payload = 7;
|
||||||
string party_id = 2;
|
int64 created_at = 8; // Unix timestamp milliseconds
|
||||||
int64 after_timestamp = 3; // Get messages after this timestamp
|
}
|
||||||
}
|
|
||||||
|
// GetPendingMessagesRequest retrieves pending messages
|
||||||
// GetPendingMessagesResponse contains pending messages
|
message GetPendingMessagesRequest {
|
||||||
message GetPendingMessagesResponse {
|
string session_id = 1;
|
||||||
repeated MPCMessage messages = 1;
|
string party_id = 2;
|
||||||
}
|
int64 after_timestamp = 3; // Get messages after this timestamp
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPendingMessagesResponse contains pending messages
|
||||||
|
message GetPendingMessagesResponse {
|
||||||
|
repeated MPCMessage messages = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterPartyRequest registers a party with the router
|
||||||
|
message RegisterPartyRequest {
|
||||||
|
string party_id = 1; // Unique party identifier
|
||||||
|
string party_role = 2; // persistent, delegate, or temporary
|
||||||
|
string version = 3; // Party software version
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterPartyResponse confirms party registration
|
||||||
|
message RegisterPartyResponse {
|
||||||
|
bool success = 1;
|
||||||
|
string message = 2;
|
||||||
|
int64 registered_at = 3; // Unix timestamp milliseconds
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubscribeSessionEventsRequest subscribes to session events
|
||||||
|
message SubscribeSessionEventsRequest {
|
||||||
|
string party_id = 1; // Party ID subscribing to events
|
||||||
|
repeated string event_types = 2; // Event types to subscribe (empty = all)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionEvent represents a session lifecycle event
|
||||||
|
message SessionEvent {
|
||||||
|
string event_id = 1;
|
||||||
|
string event_type = 2; // session_created, session_started, etc.
|
||||||
|
string session_id = 3;
|
||||||
|
int32 threshold_n = 4;
|
||||||
|
int32 threshold_t = 5;
|
||||||
|
repeated string selected_parties = 6; // PartyIDs selected for this session
|
||||||
|
map<string, string> join_tokens = 7; // PartyID -> JoinToken mapping
|
||||||
|
bytes message_hash = 8; // For sign sessions
|
||||||
|
int64 created_at = 9; // Unix timestamp milliseconds
|
||||||
|
int64 expires_at = 10; // Unix timestamp milliseconds
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,143 +1,143 @@
|
||||||
syntax = "proto3";
|
syntax = "proto3";
|
||||||
|
|
||||||
package mpc.coordinator.v1;
|
package mpc.coordinator.v1;
|
||||||
|
|
||||||
option go_package = "github.com/rwadurian/mpc-system/api/grpc/coordinator/v1;coordinator";
|
option go_package = "github.com/rwadurian/mpc-system/api/grpc/coordinator/v1;coordinator";
|
||||||
|
|
||||||
// SessionCoordinator service manages MPC sessions
|
// SessionCoordinator service manages MPC sessions
|
||||||
service SessionCoordinator {
|
service SessionCoordinator {
|
||||||
// Session management
|
// Session management
|
||||||
rpc CreateSession(CreateSessionRequest) returns (CreateSessionResponse);
|
rpc CreateSession(CreateSessionRequest) returns (CreateSessionResponse);
|
||||||
rpc JoinSession(JoinSessionRequest) returns (JoinSessionResponse);
|
rpc JoinSession(JoinSessionRequest) returns (JoinSessionResponse);
|
||||||
rpc GetSessionStatus(GetSessionStatusRequest) returns (GetSessionStatusResponse);
|
rpc GetSessionStatus(GetSessionStatusRequest) returns (GetSessionStatusResponse);
|
||||||
rpc MarkPartyReady(MarkPartyReadyRequest) returns (MarkPartyReadyResponse);
|
rpc MarkPartyReady(MarkPartyReadyRequest) returns (MarkPartyReadyResponse);
|
||||||
rpc StartSession(StartSessionRequest) returns (StartSessionResponse);
|
rpc StartSession(StartSessionRequest) returns (StartSessionResponse);
|
||||||
rpc ReportCompletion(ReportCompletionRequest) returns (ReportCompletionResponse);
|
rpc ReportCompletion(ReportCompletionRequest) returns (ReportCompletionResponse);
|
||||||
rpc CloseSession(CloseSessionRequest) returns (CloseSessionResponse);
|
rpc CloseSession(CloseSessionRequest) returns (CloseSessionResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateSessionRequest creates a new MPC session
|
// CreateSessionRequest creates a new MPC session
|
||||||
message CreateSessionRequest {
|
message CreateSessionRequest {
|
||||||
string session_type = 1; // "keygen" or "sign"
|
string session_type = 1; // "keygen" or "sign"
|
||||||
int32 threshold_n = 2; // Total number of parties
|
int32 threshold_n = 2; // Total number of parties
|
||||||
int32 threshold_t = 3; // Minimum required parties
|
int32 threshold_t = 3; // Minimum required parties
|
||||||
repeated ParticipantInfo participants = 4;
|
repeated ParticipantInfo participants = 4;
|
||||||
bytes message_hash = 5; // Required for sign sessions
|
bytes message_hash = 5; // Required for sign sessions
|
||||||
int64 expires_in_seconds = 6; // Session expiration time
|
int64 expires_in_seconds = 6; // Session expiration time
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParticipantInfo contains information about a participant
|
// ParticipantInfo contains information about a participant
|
||||||
message ParticipantInfo {
|
message ParticipantInfo {
|
||||||
string party_id = 1;
|
string party_id = 1;
|
||||||
DeviceInfo device_info = 2;
|
DeviceInfo device_info = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceInfo contains device information
|
// DeviceInfo contains device information
|
||||||
message DeviceInfo {
|
message DeviceInfo {
|
||||||
string device_type = 1; // android, ios, pc, server, recovery
|
string device_type = 1; // android, ios, pc, server, recovery
|
||||||
string device_id = 2;
|
string device_id = 2;
|
||||||
string platform = 3;
|
string platform = 3;
|
||||||
string app_version = 4;
|
string app_version = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateSessionResponse contains the created session info
|
// CreateSessionResponse contains the created session info
|
||||||
message CreateSessionResponse {
|
message CreateSessionResponse {
|
||||||
string session_id = 1;
|
string session_id = 1;
|
||||||
map<string, string> join_tokens = 2; // party_id -> join_token
|
map<string, string> join_tokens = 2; // party_id -> join_token
|
||||||
int64 expires_at = 3; // Unix timestamp milliseconds
|
int64 expires_at = 3; // Unix timestamp milliseconds
|
||||||
}
|
}
|
||||||
|
|
||||||
// JoinSessionRequest allows a participant to join a session
|
// JoinSessionRequest allows a participant to join a session
|
||||||
message JoinSessionRequest {
|
message JoinSessionRequest {
|
||||||
string session_id = 1;
|
string session_id = 1;
|
||||||
string party_id = 2;
|
string party_id = 2;
|
||||||
string join_token = 3;
|
string join_token = 3;
|
||||||
DeviceInfo device_info = 4;
|
DeviceInfo device_info = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
// JoinSessionResponse contains session information for the joining party
|
// JoinSessionResponse contains session information for the joining party
|
||||||
message JoinSessionResponse {
|
message JoinSessionResponse {
|
||||||
bool success = 1;
|
bool success = 1;
|
||||||
SessionInfo session_info = 2;
|
SessionInfo session_info = 2;
|
||||||
repeated PartyInfo other_parties = 3;
|
repeated PartyInfo other_parties = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionInfo contains session information
|
// SessionInfo contains session information
|
||||||
message SessionInfo {
|
message SessionInfo {
|
||||||
string session_id = 1;
|
string session_id = 1;
|
||||||
string session_type = 2;
|
string session_type = 2;
|
||||||
int32 threshold_n = 3;
|
int32 threshold_n = 3;
|
||||||
int32 threshold_t = 4;
|
int32 threshold_t = 4;
|
||||||
bytes message_hash = 5;
|
bytes message_hash = 5;
|
||||||
string status = 6;
|
string status = 6;
|
||||||
}
|
}
|
||||||
|
|
||||||
// PartyInfo contains party information
|
// PartyInfo contains party information
|
||||||
message PartyInfo {
|
message PartyInfo {
|
||||||
string party_id = 1;
|
string party_id = 1;
|
||||||
int32 party_index = 2;
|
int32 party_index = 2;
|
||||||
DeviceInfo device_info = 3;
|
DeviceInfo device_info = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSessionStatusRequest queries session status
|
// GetSessionStatusRequest queries session status
|
||||||
message GetSessionStatusRequest {
|
message GetSessionStatusRequest {
|
||||||
string session_id = 1;
|
string session_id = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSessionStatusResponse contains session status
|
// GetSessionStatusResponse contains session status
|
||||||
message GetSessionStatusResponse {
|
message GetSessionStatusResponse {
|
||||||
string status = 1;
|
string status = 1;
|
||||||
int32 completed_parties = 2;
|
int32 completed_parties = 2;
|
||||||
int32 total_parties = 3;
|
int32 total_parties = 3;
|
||||||
bytes public_key = 4; // For completed keygen
|
bytes public_key = 4; // For completed keygen
|
||||||
bytes signature = 5; // For completed sign
|
bytes signature = 5; // For completed sign
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReportCompletionRequest reports that a participant has completed
|
// ReportCompletionRequest reports that a participant has completed
|
||||||
message ReportCompletionRequest {
|
message ReportCompletionRequest {
|
||||||
string session_id = 1;
|
string session_id = 1;
|
||||||
string party_id = 2;
|
string party_id = 2;
|
||||||
bytes public_key = 3; // For keygen completion
|
bytes public_key = 3; // For keygen completion
|
||||||
bytes signature = 4; // For sign completion
|
bytes signature = 4; // For sign completion
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReportCompletionResponse contains the result of completion report
|
// ReportCompletionResponse contains the result of completion report
|
||||||
message ReportCompletionResponse {
|
message ReportCompletionResponse {
|
||||||
bool success = 1;
|
bool success = 1;
|
||||||
bool all_completed = 2;
|
bool all_completed = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseSessionRequest closes a session
|
// CloseSessionRequest closes a session
|
||||||
message CloseSessionRequest {
|
message CloseSessionRequest {
|
||||||
string session_id = 1;
|
string session_id = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseSessionResponse contains the result of session closure
|
// CloseSessionResponse contains the result of session closure
|
||||||
message CloseSessionResponse {
|
message CloseSessionResponse {
|
||||||
bool success = 1;
|
bool success = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkPartyReadyRequest marks a party as ready to start the protocol
|
// MarkPartyReadyRequest marks a party as ready to start the protocol
|
||||||
message MarkPartyReadyRequest {
|
message MarkPartyReadyRequest {
|
||||||
string session_id = 1;
|
string session_id = 1;
|
||||||
string party_id = 2;
|
string party_id = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkPartyReadyResponse contains the result of marking party ready
|
// MarkPartyReadyResponse contains the result of marking party ready
|
||||||
message MarkPartyReadyResponse {
|
message MarkPartyReadyResponse {
|
||||||
bool success = 1;
|
bool success = 1;
|
||||||
bool all_ready = 2; // True if all parties are ready
|
bool all_ready = 2; // True if all parties are ready
|
||||||
int32 ready_count = 3;
|
int32 ready_count = 3;
|
||||||
int32 total_parties = 4;
|
int32 total_parties = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartSessionRequest starts the MPC protocol execution
|
// StartSessionRequest starts the MPC protocol execution
|
||||||
message StartSessionRequest {
|
message StartSessionRequest {
|
||||||
string session_id = 1;
|
string session_id = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartSessionResponse contains the result of starting the session
|
// StartSessionResponse contains the result of starting the session
|
||||||
message StartSessionResponse {
|
message StartSessionResponse {
|
||||||
bool success = 1;
|
bool success = 1;
|
||||||
string status = 2; // New session status
|
string status = 2; // New session status
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,69 +1,69 @@
|
||||||
# MPC System Configuration Example
|
# MPC System Configuration Example
|
||||||
# Copy this file to config.yaml and modify as needed
|
# Copy this file to config.yaml and modify as needed
|
||||||
|
|
||||||
# Server configuration
|
# Server configuration
|
||||||
server:
|
server:
|
||||||
grpc_port: 50051
|
grpc_port: 50051
|
||||||
http_port: 8080
|
http_port: 8080
|
||||||
environment: development # development, staging, production
|
environment: development # development, staging, production
|
||||||
timeout: 30s
|
timeout: 30s
|
||||||
tls_enabled: false
|
tls_enabled: false
|
||||||
tls_cert_file: ""
|
tls_cert_file: ""
|
||||||
tls_key_file: ""
|
tls_key_file: ""
|
||||||
|
|
||||||
# Database configuration (PostgreSQL)
|
# Database configuration (PostgreSQL)
|
||||||
database:
|
database:
|
||||||
host: localhost
|
host: localhost
|
||||||
port: 5432
|
port: 5432
|
||||||
user: mpc_user
|
user: mpc_user
|
||||||
password: mpc_secret_password
|
password: mpc_secret_password
|
||||||
dbname: mpc_system
|
dbname: mpc_system
|
||||||
sslmode: disable # disable, require, verify-ca, verify-full
|
sslmode: disable # disable, require, verify-ca, verify-full
|
||||||
max_open_conns: 25
|
max_open_conns: 25
|
||||||
max_idle_conns: 5
|
max_idle_conns: 5
|
||||||
conn_max_life: 5m
|
conn_max_life: 5m
|
||||||
|
|
||||||
# Redis configuration
|
# Redis configuration
|
||||||
redis:
|
redis:
|
||||||
host: localhost
|
host: localhost
|
||||||
port: 6379
|
port: 6379
|
||||||
password: ""
|
password: ""
|
||||||
db: 0
|
db: 0
|
||||||
|
|
||||||
# RabbitMQ configuration
|
# RabbitMQ configuration
|
||||||
rabbitmq:
|
rabbitmq:
|
||||||
host: localhost
|
host: localhost
|
||||||
port: 5672
|
port: 5672
|
||||||
user: mpc_user
|
user: mpc_user
|
||||||
password: mpc_rabbit_password
|
password: mpc_rabbit_password
|
||||||
vhost: /
|
vhost: /
|
||||||
|
|
||||||
# Consul configuration (optional, for service discovery)
|
# Consul configuration (optional, for service discovery)
|
||||||
consul:
|
consul:
|
||||||
host: localhost
|
host: localhost
|
||||||
port: 8500
|
port: 8500
|
||||||
service_id: ""
|
service_id: ""
|
||||||
tags: []
|
tags: []
|
||||||
|
|
||||||
# JWT configuration
|
# JWT configuration
|
||||||
jwt:
|
jwt:
|
||||||
secret_key: "change-this-to-a-secure-random-string-in-production"
|
secret_key: "change-this-to-a-secure-random-string-in-production"
|
||||||
issuer: mpc-system
|
issuer: mpc-system
|
||||||
token_expiry: 15m
|
token_expiry: 15m
|
||||||
refresh_expiry: 24h
|
refresh_expiry: 24h
|
||||||
|
|
||||||
# MPC configuration
|
# MPC configuration
|
||||||
mpc:
|
mpc:
|
||||||
default_threshold_n: 3
|
default_threshold_n: 3
|
||||||
default_threshold_t: 2
|
default_threshold_t: 2
|
||||||
session_timeout: 10m
|
session_timeout: 10m
|
||||||
message_timeout: 30s
|
message_timeout: 30s
|
||||||
keygen_timeout: 10m
|
keygen_timeout: 10m
|
||||||
signing_timeout: 5m
|
signing_timeout: 5m
|
||||||
max_parties: 10
|
max_parties: 10
|
||||||
|
|
||||||
# Logger configuration
|
# Logger configuration
|
||||||
logger:
|
logger:
|
||||||
level: info # debug, info, warn, error
|
level: info # debug, info, warn, error
|
||||||
encoding: json # json, console
|
encoding: json # json, console
|
||||||
output_path: stdout
|
output_path: stdout
|
||||||
|
|
|
||||||
|
|
@ -1,243 +1,243 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# MPC System - Deployment Script
|
# MPC System - Deployment Script
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# This script manages the MPC System Docker services
|
# This script manages the MPC System Docker services
|
||||||
#
|
#
|
||||||
# External Ports:
|
# External Ports:
|
||||||
# 4000 - Account Service HTTP API
|
# 4000 - Account Service HTTP API
|
||||||
# 8081 - Session Coordinator API
|
# 8081 - Session Coordinator API
|
||||||
# 8082 - Message Router WebSocket
|
# 8082 - Message Router WebSocket
|
||||||
# 8083 - Server Party API (user share generation)
|
# 8083 - Server Party API (user share generation)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
# Colors
|
# Colors
|
||||||
RED='\033[0;31m'
|
RED='\033[0;31m'
|
||||||
GREEN='\033[0;32m'
|
GREEN='\033[0;32m'
|
||||||
YELLOW='\033[1;33m'
|
YELLOW='\033[1;33m'
|
||||||
BLUE='\033[0;34m'
|
BLUE='\033[0;34m'
|
||||||
NC='\033[0m'
|
NC='\033[0m'
|
||||||
|
|
||||||
log_info() { echo -e "${BLUE}[INFO]${NC} $1"; }
|
log_info() { echo -e "${BLUE}[INFO]${NC} $1"; }
|
||||||
log_success() { echo -e "${GREEN}[OK]${NC} $1"; }
|
log_success() { echo -e "${GREEN}[OK]${NC} $1"; }
|
||||||
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
||||||
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
||||||
|
|
||||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
cd "$SCRIPT_DIR"
|
cd "$SCRIPT_DIR"
|
||||||
|
|
||||||
# Load environment
|
# Load environment
|
||||||
if [ -f ".env" ]; then
|
if [ -f ".env" ]; then
|
||||||
log_info "Loading environment from .env file"
|
log_info "Loading environment from .env file"
|
||||||
set -a
|
set -a
|
||||||
source .env
|
source .env
|
||||||
set +a
|
set +a
|
||||||
elif [ ! -f ".env" ] && [ -f ".env.example" ]; then
|
elif [ ! -f ".env" ] && [ -f ".env.example" ]; then
|
||||||
log_warn ".env file not found. Creating from .env.example"
|
log_warn ".env file not found. Creating from .env.example"
|
||||||
log_warn "Please edit .env and configure for your environment!"
|
log_warn "Please edit .env and configure for your environment!"
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
log_error "Please configure .env file and run again"
|
log_error "Please configure .env file and run again"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Core services list
|
# Core services list
|
||||||
CORE_SERVICES="postgres redis rabbitmq"
|
CORE_SERVICES="postgres redis rabbitmq"
|
||||||
MPC_SERVICES="session-coordinator message-router server-party-1 server-party-2 server-party-3 server-party-api account-service"
|
MPC_SERVICES="session-coordinator message-router server-party-1 server-party-2 server-party-3 server-party-api account-service"
|
||||||
ALL_SERVICES="$CORE_SERVICES $MPC_SERVICES"
|
ALL_SERVICES="$CORE_SERVICES $MPC_SERVICES"
|
||||||
|
|
||||||
case "$1" in
|
case "$1" in
|
||||||
build)
|
build)
|
||||||
log_info "Building MPC System services..."
|
log_info "Building MPC System services..."
|
||||||
docker compose build
|
docker compose build
|
||||||
log_success "MPC System built successfully"
|
log_success "MPC System built successfully"
|
||||||
;;
|
;;
|
||||||
|
|
||||||
build-no-cache)
|
build-no-cache)
|
||||||
log_info "Building MPC System (no cache)..."
|
log_info "Building MPC System (no cache)..."
|
||||||
docker compose build --no-cache
|
docker compose build --no-cache
|
||||||
log_success "MPC System built successfully"
|
log_success "MPC System built successfully"
|
||||||
;;
|
;;
|
||||||
|
|
||||||
up|start)
|
up|start)
|
||||||
log_info "Starting MPC System..."
|
log_info "Starting MPC System..."
|
||||||
docker compose up -d
|
docker compose up -d
|
||||||
log_success "MPC System started"
|
log_success "MPC System started"
|
||||||
echo ""
|
echo ""
|
||||||
log_info "Services status:"
|
log_info "Services status:"
|
||||||
docker compose ps
|
docker compose ps
|
||||||
;;
|
;;
|
||||||
|
|
||||||
down|stop)
|
down|stop)
|
||||||
log_info "Stopping MPC System..."
|
log_info "Stopping MPC System..."
|
||||||
docker compose down
|
docker compose down
|
||||||
log_success "MPC System stopped"
|
log_success "MPC System stopped"
|
||||||
;;
|
;;
|
||||||
|
|
||||||
restart)
|
restart)
|
||||||
log_info "Restarting MPC System..."
|
log_info "Restarting MPC System..."
|
||||||
docker compose down
|
docker compose down
|
||||||
docker compose up -d
|
docker compose up -d
|
||||||
log_success "MPC System restarted"
|
log_success "MPC System restarted"
|
||||||
;;
|
;;
|
||||||
|
|
||||||
logs)
|
logs)
|
||||||
if [ -n "$2" ]; then
|
if [ -n "$2" ]; then
|
||||||
docker compose logs -f "$2"
|
docker compose logs -f "$2"
|
||||||
else
|
else
|
||||||
docker compose logs -f
|
docker compose logs -f
|
||||||
fi
|
fi
|
||||||
;;
|
;;
|
||||||
|
|
||||||
logs-tail)
|
logs-tail)
|
||||||
if [ -n "$2" ]; then
|
if [ -n "$2" ]; then
|
||||||
docker compose logs --tail 100 "$2"
|
docker compose logs --tail 100 "$2"
|
||||||
else
|
else
|
||||||
docker compose logs --tail 100
|
docker compose logs --tail 100
|
||||||
fi
|
fi
|
||||||
;;
|
;;
|
||||||
|
|
||||||
status|ps)
|
status|ps)
|
||||||
log_info "MPC System status:"
|
log_info "MPC System status:"
|
||||||
docker compose ps
|
docker compose ps
|
||||||
;;
|
;;
|
||||||
|
|
||||||
health)
|
health)
|
||||||
log_info "Checking MPC System health..."
|
log_info "Checking MPC System health..."
|
||||||
|
|
||||||
# Check infrastructure
|
# Check infrastructure
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== Infrastructure ==="
|
echo "=== Infrastructure ==="
|
||||||
for svc in $CORE_SERVICES; do
|
for svc in $CORE_SERVICES; do
|
||||||
if docker compose ps "$svc" --format json 2>/dev/null | grep -q '"Health":"healthy"'; then
|
if docker compose ps "$svc" --format json 2>/dev/null | grep -q '"Health":"healthy"'; then
|
||||||
log_success "$svc is healthy"
|
log_success "$svc is healthy"
|
||||||
else
|
else
|
||||||
log_warn "$svc is not healthy"
|
log_warn "$svc is not healthy"
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
|
|
||||||
# Check MPC services
|
# Check MPC services
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== MPC Services ==="
|
echo "=== MPC Services ==="
|
||||||
for svc in $MPC_SERVICES; do
|
for svc in $MPC_SERVICES; do
|
||||||
if docker compose ps "$svc" --format json 2>/dev/null | grep -q '"Health":"healthy"'; then
|
if docker compose ps "$svc" --format json 2>/dev/null | grep -q '"Health":"healthy"'; then
|
||||||
log_success "$svc is healthy"
|
log_success "$svc is healthy"
|
||||||
else
|
else
|
||||||
log_warn "$svc is not healthy"
|
log_warn "$svc is not healthy"
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
|
|
||||||
# Check external API
|
# Check external API
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== External API ==="
|
echo "=== External API ==="
|
||||||
if curl -sf "http://localhost:4000/health" > /dev/null 2>&1; then
|
if curl -sf "http://localhost:4000/health" > /dev/null 2>&1; then
|
||||||
log_success "Account Service API (port 4000) is accessible"
|
log_success "Account Service API (port 4000) is accessible"
|
||||||
else
|
else
|
||||||
log_error "Account Service API (port 4000) is not accessible"
|
log_error "Account Service API (port 4000) is not accessible"
|
||||||
fi
|
fi
|
||||||
;;
|
;;
|
||||||
|
|
||||||
infra)
|
infra)
|
||||||
case "$2" in
|
case "$2" in
|
||||||
up)
|
up)
|
||||||
log_info "Starting infrastructure services..."
|
log_info "Starting infrastructure services..."
|
||||||
docker compose up -d $CORE_SERVICES
|
docker compose up -d $CORE_SERVICES
|
||||||
log_success "Infrastructure started"
|
log_success "Infrastructure started"
|
||||||
;;
|
;;
|
||||||
down)
|
down)
|
||||||
log_info "Stopping infrastructure services..."
|
log_info "Stopping infrastructure services..."
|
||||||
docker compose stop $CORE_SERVICES
|
docker compose stop $CORE_SERVICES
|
||||||
log_success "Infrastructure stopped"
|
log_success "Infrastructure stopped"
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Usage: $0 infra {up|down}"
|
echo "Usage: $0 infra {up|down}"
|
||||||
exit 1
|
exit 1
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
;;
|
;;
|
||||||
|
|
||||||
mpc)
|
mpc)
|
||||||
case "$2" in
|
case "$2" in
|
||||||
up)
|
up)
|
||||||
log_info "Starting MPC services..."
|
log_info "Starting MPC services..."
|
||||||
docker compose up -d $MPC_SERVICES
|
docker compose up -d $MPC_SERVICES
|
||||||
log_success "MPC services started"
|
log_success "MPC services started"
|
||||||
;;
|
;;
|
||||||
down)
|
down)
|
||||||
log_info "Stopping MPC services..."
|
log_info "Stopping MPC services..."
|
||||||
docker compose stop $MPC_SERVICES
|
docker compose stop $MPC_SERVICES
|
||||||
log_success "MPC services stopped"
|
log_success "MPC services stopped"
|
||||||
;;
|
;;
|
||||||
restart)
|
restart)
|
||||||
log_info "Restarting MPC services..."
|
log_info "Restarting MPC services..."
|
||||||
docker compose stop $MPC_SERVICES
|
docker compose stop $MPC_SERVICES
|
||||||
docker compose up -d $MPC_SERVICES
|
docker compose up -d $MPC_SERVICES
|
||||||
log_success "MPC services restarted"
|
log_success "MPC services restarted"
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Usage: $0 mpc {up|down|restart}"
|
echo "Usage: $0 mpc {up|down|restart}"
|
||||||
exit 1
|
exit 1
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
;;
|
;;
|
||||||
|
|
||||||
clean)
|
clean)
|
||||||
log_warn "This will remove all containers and volumes!"
|
log_warn "This will remove all containers and volumes!"
|
||||||
read -p "Are you sure? (y/N) " -n 1 -r
|
read -p "Are you sure? (y/N) " -n 1 -r
|
||||||
echo
|
echo
|
||||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
||||||
docker compose down -v
|
docker compose down -v
|
||||||
log_success "MPC System cleaned"
|
log_success "MPC System cleaned"
|
||||||
else
|
else
|
||||||
log_info "Cancelled"
|
log_info "Cancelled"
|
||||||
fi
|
fi
|
||||||
;;
|
;;
|
||||||
|
|
||||||
shell)
|
shell)
|
||||||
if [ -n "$2" ]; then
|
if [ -n "$2" ]; then
|
||||||
log_info "Opening shell in $2..."
|
log_info "Opening shell in $2..."
|
||||||
docker compose exec "$2" sh
|
docker compose exec "$2" sh
|
||||||
else
|
else
|
||||||
log_info "Opening shell in account-service..."
|
log_info "Opening shell in account-service..."
|
||||||
docker compose exec account-service sh
|
docker compose exec account-service sh
|
||||||
fi
|
fi
|
||||||
;;
|
;;
|
||||||
|
|
||||||
test-api)
|
test-api)
|
||||||
log_info "Testing Account Service API..."
|
log_info "Testing Account Service API..."
|
||||||
echo ""
|
echo ""
|
||||||
echo "Health check:"
|
echo "Health check:"
|
||||||
curl -s "http://localhost:4000/health" | jq . 2>/dev/null || curl -s "http://localhost:4000/health"
|
curl -s "http://localhost:4000/health" | jq . 2>/dev/null || curl -s "http://localhost:4000/health"
|
||||||
echo ""
|
echo ""
|
||||||
;;
|
;;
|
||||||
|
|
||||||
*)
|
*)
|
||||||
echo "MPC System Deployment Script"
|
echo "MPC System Deployment Script"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Usage: $0 <command> [options]"
|
echo "Usage: $0 <command> [options]"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Commands:"
|
echo "Commands:"
|
||||||
echo " build - Build all Docker images"
|
echo " build - Build all Docker images"
|
||||||
echo " build-no-cache - Build images without cache"
|
echo " build-no-cache - Build images without cache"
|
||||||
echo " up|start - Start all services"
|
echo " up|start - Start all services"
|
||||||
echo " down|stop - Stop all services"
|
echo " down|stop - Stop all services"
|
||||||
echo " restart - Restart all services"
|
echo " restart - Restart all services"
|
||||||
echo " logs [service] - Follow logs (all or specific service)"
|
echo " logs [service] - Follow logs (all or specific service)"
|
||||||
echo " logs-tail [svc] - Show last 100 log lines"
|
echo " logs-tail [svc] - Show last 100 log lines"
|
||||||
echo " status|ps - Show services status"
|
echo " status|ps - Show services status"
|
||||||
echo " health - Check all services health"
|
echo " health - Check all services health"
|
||||||
echo ""
|
echo ""
|
||||||
echo " infra up|down - Start/stop infrastructure only"
|
echo " infra up|down - Start/stop infrastructure only"
|
||||||
echo " mpc up|down|restart - Start/stop/restart MPC services only"
|
echo " mpc up|down|restart - Start/stop/restart MPC services only"
|
||||||
echo ""
|
echo ""
|
||||||
echo " shell [service] - Open shell in container"
|
echo " shell [service] - Open shell in container"
|
||||||
echo " test-api - Test Account Service API"
|
echo " test-api - Test Account Service API"
|
||||||
echo " clean - Remove all containers and volumes"
|
echo " clean - Remove all containers and volumes"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Services:"
|
echo "Services:"
|
||||||
echo " Infrastructure: $CORE_SERVICES"
|
echo " Infrastructure: $CORE_SERVICES"
|
||||||
echo " MPC Services: $MPC_SERVICES"
|
echo " MPC Services: $MPC_SERVICES"
|
||||||
exit 1
|
exit 1
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
|
|
|
||||||
|
|
@ -1,391 +1,391 @@
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# MPC-System Docker Compose Configuration
|
# MPC-System Docker Compose Configuration
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Purpose: TSS (Threshold Signature Scheme) key generation and signing service
|
# Purpose: TSS (Threshold Signature Scheme) key generation and signing service
|
||||||
#
|
#
|
||||||
# Usage:
|
# Usage:
|
||||||
# Development: docker compose up -d
|
# Development: docker compose up -d
|
||||||
# Production: docker compose --env-file .env up -d
|
# Production: docker compose --env-file .env up -d
|
||||||
#
|
#
|
||||||
# External Ports:
|
# External Ports:
|
||||||
# 4000 - Account Service HTTP API (accessed by backend mpc-service)
|
# 4000 - Account Service HTTP API (accessed by backend mpc-service)
|
||||||
# 8081 - Session Coordinator API (accessed by backend mpc-service)
|
# 8081 - Session Coordinator API (accessed by backend mpc-service)
|
||||||
# 8082 - Message Router WebSocket (accessed by backend mpc-service)
|
# 8082 - Message Router WebSocket (accessed by backend mpc-service)
|
||||||
# 8083 - Server Party API (accessed by backend mpc-service for user share generation)
|
# 8083 - Server Party API (accessed by backend mpc-service for user share generation)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
services:
|
services:
|
||||||
# ============================================
|
# ============================================
|
||||||
# Infrastructure Services
|
# Infrastructure Services
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
# PostgreSQL Database
|
# PostgreSQL Database
|
||||||
postgres:
|
postgres:
|
||||||
image: postgres:15-alpine
|
image: postgres:15-alpine
|
||||||
container_name: mpc-postgres
|
container_name: mpc-postgres
|
||||||
environment:
|
environment:
|
||||||
POSTGRES_DB: mpc_system
|
POSTGRES_DB: mpc_system
|
||||||
POSTGRES_USER: ${POSTGRES_USER:-mpc_user}
|
POSTGRES_USER: ${POSTGRES_USER:-mpc_user}
|
||||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set in .env}
|
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set in .env}
|
||||||
volumes:
|
volumes:
|
||||||
- postgres-data:/var/lib/postgresql/data
|
- postgres-data:/var/lib/postgresql/data
|
||||||
- ./migrations:/docker-entrypoint-initdb.d:ro
|
- ./migrations:/docker-entrypoint-initdb.d:ro
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-mpc_user} -d mpc_system"]
|
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-mpc_user} -d mpc_system"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
timeout: 5s
|
timeout: 5s
|
||||||
retries: 5
|
retries: 5
|
||||||
start_period: 30s
|
start_period: 30s
|
||||||
networks:
|
networks:
|
||||||
- mpc-network
|
- mpc-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
# 生产环境不暴露端口到主机,仅内部网络可访问
|
# 生产环境不暴露端口到主机,仅内部网络可访问
|
||||||
# ports:
|
# ports:
|
||||||
# - "5432:5432"
|
# - "5432:5432"
|
||||||
|
|
||||||
# Redis Cache
|
# Redis Cache
|
||||||
redis:
|
redis:
|
||||||
image: redis:7-alpine
|
image: redis:7-alpine
|
||||||
container_name: mpc-redis
|
container_name: mpc-redis
|
||||||
command: redis-server --appendonly yes --maxmemory 512mb --maxmemory-policy allkeys-lru ${REDIS_PASSWORD:+--requirepass $REDIS_PASSWORD}
|
command: redis-server --appendonly yes --maxmemory 512mb --maxmemory-policy allkeys-lru ${REDIS_PASSWORD:+--requirepass $REDIS_PASSWORD}
|
||||||
volumes:
|
volumes:
|
||||||
- redis-data:/data
|
- redis-data:/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "redis-cli", "ping"]
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
timeout: 5s
|
timeout: 5s
|
||||||
retries: 5
|
retries: 5
|
||||||
networks:
|
networks:
|
||||||
- mpc-network
|
- mpc-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
# RabbitMQ Message Broker
|
# RabbitMQ Message Broker
|
||||||
rabbitmq:
|
rabbitmq:
|
||||||
image: rabbitmq:3-management-alpine
|
image: rabbitmq:3-management-alpine
|
||||||
container_name: mpc-rabbitmq
|
container_name: mpc-rabbitmq
|
||||||
environment:
|
environment:
|
||||||
RABBITMQ_DEFAULT_USER: ${RABBITMQ_USER:-mpc_user}
|
RABBITMQ_DEFAULT_USER: ${RABBITMQ_USER:-mpc_user}
|
||||||
RABBITMQ_DEFAULT_PASS: ${RABBITMQ_PASSWORD:?RABBITMQ_PASSWORD must be set in .env}
|
RABBITMQ_DEFAULT_PASS: ${RABBITMQ_PASSWORD:?RABBITMQ_PASSWORD must be set in .env}
|
||||||
RABBITMQ_DEFAULT_VHOST: /
|
RABBITMQ_DEFAULT_VHOST: /
|
||||||
volumes:
|
volumes:
|
||||||
- rabbitmq-data:/var/lib/rabbitmq
|
- rabbitmq-data:/var/lib/rabbitmq
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "rabbitmq-diagnostics", "-q", "ping"]
|
test: ["CMD", "rabbitmq-diagnostics", "-q", "ping"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 5
|
retries: 5
|
||||||
start_period: 30s
|
start_period: 30s
|
||||||
networks:
|
networks:
|
||||||
- mpc-network
|
- mpc-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
# 生产环境管理界面仅开发时使用
|
# 生产环境管理界面仅开发时使用
|
||||||
# ports:
|
# ports:
|
||||||
# - "15672:15672"
|
# - "15672:15672"
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# MPC Core Services
|
# MPC Core Services
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
# Session Coordinator Service - 会话协调器
|
# Session Coordinator Service - 会话协调器
|
||||||
session-coordinator:
|
session-coordinator:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: services/session-coordinator/Dockerfile
|
dockerfile: services/session-coordinator/Dockerfile
|
||||||
container_name: mpc-session-coordinator
|
container_name: mpc-session-coordinator
|
||||||
ports:
|
ports:
|
||||||
- "8081:8080" # HTTP API for external access
|
- "8081:8080" # HTTP API for external access
|
||||||
environment:
|
environment:
|
||||||
MPC_SERVER_GRPC_PORT: 50051
|
MPC_SERVER_GRPC_PORT: 50051
|
||||||
MPC_SERVER_HTTP_PORT: 8080
|
MPC_SERVER_HTTP_PORT: 8080
|
||||||
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
||||||
MPC_DATABASE_HOST: postgres
|
MPC_DATABASE_HOST: postgres
|
||||||
MPC_DATABASE_PORT: 5432
|
MPC_DATABASE_PORT: 5432
|
||||||
MPC_DATABASE_USER: ${POSTGRES_USER:-mpc_user}
|
MPC_DATABASE_USER: ${POSTGRES_USER:-mpc_user}
|
||||||
MPC_DATABASE_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
|
MPC_DATABASE_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
|
||||||
MPC_DATABASE_DBNAME: mpc_system
|
MPC_DATABASE_DBNAME: mpc_system
|
||||||
MPC_DATABASE_SSLMODE: disable
|
MPC_DATABASE_SSLMODE: disable
|
||||||
MPC_REDIS_HOST: redis
|
MPC_REDIS_HOST: redis
|
||||||
MPC_REDIS_PORT: 6379
|
MPC_REDIS_PORT: 6379
|
||||||
MPC_REDIS_PASSWORD: ${REDIS_PASSWORD:-}
|
MPC_REDIS_PASSWORD: ${REDIS_PASSWORD:-}
|
||||||
MPC_RABBITMQ_HOST: rabbitmq
|
MPC_RABBITMQ_HOST: rabbitmq
|
||||||
MPC_RABBITMQ_PORT: 5672
|
MPC_RABBITMQ_PORT: 5672
|
||||||
MPC_RABBITMQ_USER: ${RABBITMQ_USER:-mpc_user}
|
MPC_RABBITMQ_USER: ${RABBITMQ_USER:-mpc_user}
|
||||||
MPC_RABBITMQ_PASSWORD: ${RABBITMQ_PASSWORD:?RABBITMQ_PASSWORD must be set}
|
MPC_RABBITMQ_PASSWORD: ${RABBITMQ_PASSWORD:?RABBITMQ_PASSWORD must be set}
|
||||||
MPC_JWT_SECRET_KEY: ${JWT_SECRET_KEY}
|
MPC_JWT_SECRET_KEY: ${JWT_SECRET_KEY}
|
||||||
MPC_JWT_ISSUER: mpc-system
|
MPC_JWT_ISSUER: mpc-system
|
||||||
depends_on:
|
depends_on:
|
||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
rabbitmq:
|
rabbitmq:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
start_period: 30s
|
start_period: 30s
|
||||||
networks:
|
networks:
|
||||||
- mpc-network
|
- mpc-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
# Message Router Service - 消息路由
|
# Message Router Service - 消息路由
|
||||||
message-router:
|
message-router:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: services/message-router/Dockerfile
|
dockerfile: services/message-router/Dockerfile
|
||||||
container_name: mpc-message-router
|
container_name: mpc-message-router
|
||||||
ports:
|
ports:
|
||||||
- "8082:8080" # WebSocket for external connections
|
- "8082:8080" # WebSocket for external connections
|
||||||
environment:
|
environment:
|
||||||
MPC_SERVER_GRPC_PORT: 50051
|
MPC_SERVER_GRPC_PORT: 50051
|
||||||
MPC_SERVER_HTTP_PORT: 8080
|
MPC_SERVER_HTTP_PORT: 8080
|
||||||
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
||||||
MPC_DATABASE_HOST: postgres
|
MPC_DATABASE_HOST: postgres
|
||||||
MPC_DATABASE_PORT: 5432
|
MPC_DATABASE_PORT: 5432
|
||||||
MPC_DATABASE_USER: ${POSTGRES_USER:-mpc_user}
|
MPC_DATABASE_USER: ${POSTGRES_USER:-mpc_user}
|
||||||
MPC_DATABASE_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
|
MPC_DATABASE_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
|
||||||
MPC_DATABASE_DBNAME: mpc_system
|
MPC_DATABASE_DBNAME: mpc_system
|
||||||
MPC_DATABASE_SSLMODE: disable
|
MPC_DATABASE_SSLMODE: disable
|
||||||
MPC_RABBITMQ_HOST: rabbitmq
|
MPC_RABBITMQ_HOST: rabbitmq
|
||||||
MPC_RABBITMQ_PORT: 5672
|
MPC_RABBITMQ_PORT: 5672
|
||||||
MPC_RABBITMQ_USER: ${RABBITMQ_USER:-mpc_user}
|
MPC_RABBITMQ_USER: ${RABBITMQ_USER:-mpc_user}
|
||||||
MPC_RABBITMQ_PASSWORD: ${RABBITMQ_PASSWORD:?RABBITMQ_PASSWORD must be set}
|
MPC_RABBITMQ_PASSWORD: ${RABBITMQ_PASSWORD:?RABBITMQ_PASSWORD must be set}
|
||||||
depends_on:
|
depends_on:
|
||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
rabbitmq:
|
rabbitmq:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
start_period: 30s
|
start_period: 30s
|
||||||
networks:
|
networks:
|
||||||
- mpc-network
|
- mpc-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Server Party Services - TSS 参与方
|
# Server Party Services - TSS 参与方
|
||||||
# 2-of-3 阈值签名: 至少 2 个 party 参与才能完成签名
|
# 2-of-3 阈值签名: 至少 2 个 party 参与才能完成签名
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
# Server Party 1
|
# Server Party 1
|
||||||
server-party-1:
|
server-party-1:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: services/server-party/Dockerfile
|
dockerfile: services/server-party/Dockerfile
|
||||||
container_name: mpc-server-party-1
|
container_name: mpc-server-party-1
|
||||||
environment:
|
environment:
|
||||||
MPC_SERVER_GRPC_PORT: 50051
|
MPC_SERVER_GRPC_PORT: 50051
|
||||||
MPC_SERVER_HTTP_PORT: 8080
|
MPC_SERVER_HTTP_PORT: 8080
|
||||||
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
||||||
MPC_DATABASE_HOST: postgres
|
MPC_DATABASE_HOST: postgres
|
||||||
MPC_DATABASE_PORT: 5432
|
MPC_DATABASE_PORT: 5432
|
||||||
MPC_DATABASE_USER: ${POSTGRES_USER:-mpc_user}
|
MPC_DATABASE_USER: ${POSTGRES_USER:-mpc_user}
|
||||||
MPC_DATABASE_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
|
MPC_DATABASE_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
|
||||||
MPC_DATABASE_DBNAME: mpc_system
|
MPC_DATABASE_DBNAME: mpc_system
|
||||||
MPC_DATABASE_SSLMODE: disable
|
MPC_DATABASE_SSLMODE: disable
|
||||||
SESSION_COORDINATOR_ADDR: session-coordinator:50051
|
SESSION_COORDINATOR_ADDR: session-coordinator:50051
|
||||||
MESSAGE_ROUTER_ADDR: message-router:50051
|
MESSAGE_ROUTER_ADDR: message-router:50051
|
||||||
MPC_CRYPTO_MASTER_KEY: ${CRYPTO_MASTER_KEY}
|
MPC_CRYPTO_MASTER_KEY: ${CRYPTO_MASTER_KEY}
|
||||||
PARTY_ID: server-party-1
|
PARTY_ID: server-party-1
|
||||||
depends_on:
|
depends_on:
|
||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
session-coordinator:
|
session-coordinator:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
message-router:
|
message-router:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
start_period: 30s
|
start_period: 30s
|
||||||
networks:
|
networks:
|
||||||
- mpc-network
|
- mpc-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
# Server Party 2
|
# Server Party 2
|
||||||
server-party-2:
|
server-party-2:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: services/server-party/Dockerfile
|
dockerfile: services/server-party/Dockerfile
|
||||||
container_name: mpc-server-party-2
|
container_name: mpc-server-party-2
|
||||||
environment:
|
environment:
|
||||||
MPC_SERVER_GRPC_PORT: 50051
|
MPC_SERVER_GRPC_PORT: 50051
|
||||||
MPC_SERVER_HTTP_PORT: 8080
|
MPC_SERVER_HTTP_PORT: 8080
|
||||||
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
||||||
MPC_DATABASE_HOST: postgres
|
MPC_DATABASE_HOST: postgres
|
||||||
MPC_DATABASE_PORT: 5432
|
MPC_DATABASE_PORT: 5432
|
||||||
MPC_DATABASE_USER: ${POSTGRES_USER:-mpc_user}
|
MPC_DATABASE_USER: ${POSTGRES_USER:-mpc_user}
|
||||||
MPC_DATABASE_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
|
MPC_DATABASE_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
|
||||||
MPC_DATABASE_DBNAME: mpc_system
|
MPC_DATABASE_DBNAME: mpc_system
|
||||||
MPC_DATABASE_SSLMODE: disable
|
MPC_DATABASE_SSLMODE: disable
|
||||||
SESSION_COORDINATOR_ADDR: session-coordinator:50051
|
SESSION_COORDINATOR_ADDR: session-coordinator:50051
|
||||||
MESSAGE_ROUTER_ADDR: message-router:50051
|
MESSAGE_ROUTER_ADDR: message-router:50051
|
||||||
MPC_CRYPTO_MASTER_KEY: ${CRYPTO_MASTER_KEY}
|
MPC_CRYPTO_MASTER_KEY: ${CRYPTO_MASTER_KEY}
|
||||||
PARTY_ID: server-party-2
|
PARTY_ID: server-party-2
|
||||||
depends_on:
|
depends_on:
|
||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
session-coordinator:
|
session-coordinator:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
message-router:
|
message-router:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
start_period: 30s
|
start_period: 30s
|
||||||
networks:
|
networks:
|
||||||
- mpc-network
|
- mpc-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
# Server Party 3
|
# Server Party 3
|
||||||
server-party-3:
|
server-party-3:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: services/server-party/Dockerfile
|
dockerfile: services/server-party/Dockerfile
|
||||||
container_name: mpc-server-party-3
|
container_name: mpc-server-party-3
|
||||||
environment:
|
environment:
|
||||||
MPC_SERVER_GRPC_PORT: 50051
|
MPC_SERVER_GRPC_PORT: 50051
|
||||||
MPC_SERVER_HTTP_PORT: 8080
|
MPC_SERVER_HTTP_PORT: 8080
|
||||||
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
||||||
MPC_DATABASE_HOST: postgres
|
MPC_DATABASE_HOST: postgres
|
||||||
MPC_DATABASE_PORT: 5432
|
MPC_DATABASE_PORT: 5432
|
||||||
MPC_DATABASE_USER: ${POSTGRES_USER:-mpc_user}
|
MPC_DATABASE_USER: ${POSTGRES_USER:-mpc_user}
|
||||||
MPC_DATABASE_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
|
MPC_DATABASE_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
|
||||||
MPC_DATABASE_DBNAME: mpc_system
|
MPC_DATABASE_DBNAME: mpc_system
|
||||||
MPC_DATABASE_SSLMODE: disable
|
MPC_DATABASE_SSLMODE: disable
|
||||||
SESSION_COORDINATOR_ADDR: session-coordinator:50051
|
SESSION_COORDINATOR_ADDR: session-coordinator:50051
|
||||||
MESSAGE_ROUTER_ADDR: message-router:50051
|
MESSAGE_ROUTER_ADDR: message-router:50051
|
||||||
MPC_CRYPTO_MASTER_KEY: ${CRYPTO_MASTER_KEY}
|
MPC_CRYPTO_MASTER_KEY: ${CRYPTO_MASTER_KEY}
|
||||||
PARTY_ID: server-party-3
|
PARTY_ID: server-party-3
|
||||||
depends_on:
|
depends_on:
|
||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
session-coordinator:
|
session-coordinator:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
message-router:
|
message-router:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
start_period: 30s
|
start_period: 30s
|
||||||
networks:
|
networks:
|
||||||
- mpc-network
|
- mpc-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Server Party API - User Share Generation Service
|
# Server Party API - User Share Generation Service
|
||||||
# Unlike other server-party services, this one returns shares to the caller
|
# Unlike other server-party services, this one returns shares to the caller
|
||||||
# instead of storing them internally
|
# instead of storing them internally
|
||||||
# ============================================
|
# ============================================
|
||||||
server-party-api:
|
server-party-api:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: services/server-party-api/Dockerfile
|
dockerfile: services/server-party-api/Dockerfile
|
||||||
container_name: mpc-server-party-api
|
container_name: mpc-server-party-api
|
||||||
ports:
|
ports:
|
||||||
- "8083:8080" # HTTP API for user share generation
|
- "8083:8080" # HTTP API for user share generation
|
||||||
environment:
|
environment:
|
||||||
MPC_SERVER_HTTP_PORT: 8080
|
MPC_SERVER_HTTP_PORT: 8080
|
||||||
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
||||||
SESSION_COORDINATOR_ADDR: session-coordinator:50051
|
SESSION_COORDINATOR_ADDR: session-coordinator:50051
|
||||||
MESSAGE_ROUTER_ADDR: message-router:50051
|
MESSAGE_ROUTER_ADDR: message-router:50051
|
||||||
MPC_CRYPTO_MASTER_KEY: ${CRYPTO_MASTER_KEY}
|
MPC_CRYPTO_MASTER_KEY: ${CRYPTO_MASTER_KEY}
|
||||||
# API 认证密钥 (与 mpc-service 配置的 MPC_API_KEY 一致)
|
# API 认证密钥 (与 mpc-service 配置的 MPC_API_KEY 一致)
|
||||||
MPC_API_KEY: ${MPC_API_KEY}
|
MPC_API_KEY: ${MPC_API_KEY}
|
||||||
depends_on:
|
depends_on:
|
||||||
session-coordinator:
|
session-coordinator:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
message-router:
|
message-router:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
start_period: 30s
|
start_period: 30s
|
||||||
networks:
|
networks:
|
||||||
- mpc-network
|
- mpc-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Account Service - External API Entry Point
|
# Account Service - External API Entry Point
|
||||||
# Main HTTP API for backend mpc-service integration
|
# Main HTTP API for backend mpc-service integration
|
||||||
# ============================================
|
# ============================================
|
||||||
account-service:
|
account-service:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: services/account/Dockerfile
|
dockerfile: services/account/Dockerfile
|
||||||
container_name: mpc-account-service
|
container_name: mpc-account-service
|
||||||
ports:
|
ports:
|
||||||
- "4000:8080" # HTTP API for external access
|
- "4000:8080" # HTTP API for external access
|
||||||
environment:
|
environment:
|
||||||
MPC_SERVER_GRPC_PORT: 50051
|
MPC_SERVER_GRPC_PORT: 50051
|
||||||
MPC_SERVER_HTTP_PORT: 8080
|
MPC_SERVER_HTTP_PORT: 8080
|
||||||
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
MPC_SERVER_ENVIRONMENT: ${ENVIRONMENT:-production}
|
||||||
MPC_DATABASE_HOST: postgres
|
MPC_DATABASE_HOST: postgres
|
||||||
MPC_DATABASE_PORT: 5432
|
MPC_DATABASE_PORT: 5432
|
||||||
MPC_DATABASE_USER: ${POSTGRES_USER:-mpc_user}
|
MPC_DATABASE_USER: ${POSTGRES_USER:-mpc_user}
|
||||||
MPC_DATABASE_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
|
MPC_DATABASE_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD must be set}
|
||||||
MPC_DATABASE_DBNAME: mpc_system
|
MPC_DATABASE_DBNAME: mpc_system
|
||||||
MPC_DATABASE_SSLMODE: disable
|
MPC_DATABASE_SSLMODE: disable
|
||||||
MPC_REDIS_HOST: redis
|
MPC_REDIS_HOST: redis
|
||||||
MPC_REDIS_PORT: 6379
|
MPC_REDIS_PORT: 6379
|
||||||
MPC_REDIS_PASSWORD: ${REDIS_PASSWORD:-}
|
MPC_REDIS_PASSWORD: ${REDIS_PASSWORD:-}
|
||||||
MPC_RABBITMQ_HOST: rabbitmq
|
MPC_RABBITMQ_HOST: rabbitmq
|
||||||
MPC_RABBITMQ_PORT: 5672
|
MPC_RABBITMQ_PORT: 5672
|
||||||
MPC_RABBITMQ_USER: ${RABBITMQ_USER:-mpc_user}
|
MPC_RABBITMQ_USER: ${RABBITMQ_USER:-mpc_user}
|
||||||
MPC_RABBITMQ_PASSWORD: ${RABBITMQ_PASSWORD:?RABBITMQ_PASSWORD must be set}
|
MPC_RABBITMQ_PASSWORD: ${RABBITMQ_PASSWORD:?RABBITMQ_PASSWORD must be set}
|
||||||
MPC_COORDINATOR_URL: session-coordinator:50051
|
MPC_COORDINATOR_URL: session-coordinator:50051
|
||||||
MPC_JWT_SECRET_KEY: ${JWT_SECRET_KEY}
|
MPC_JWT_SECRET_KEY: ${JWT_SECRET_KEY}
|
||||||
# API 认证密钥 (与 mpc-service 配置的 MPC_API_KEY 一致)
|
# API 认证密钥 (与 mpc-service 配置的 MPC_API_KEY 一致)
|
||||||
MPC_API_KEY: ${MPC_API_KEY}
|
MPC_API_KEY: ${MPC_API_KEY}
|
||||||
# Allowed source IPs (backend servers)
|
# Allowed source IPs (backend servers)
|
||||||
# Empty default = allow all (protected by API_KEY). Set in .env for production!
|
# Empty default = allow all (protected by API_KEY). Set in .env for production!
|
||||||
ALLOWED_IPS: ${ALLOWED_IPS:-}
|
ALLOWED_IPS: ${ALLOWED_IPS:-}
|
||||||
depends_on:
|
depends_on:
|
||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
rabbitmq:
|
rabbitmq:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
session-coordinator:
|
session-coordinator:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
start_period: 30s
|
start_period: 30s
|
||||||
networks:
|
networks:
|
||||||
- mpc-network
|
- mpc-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Networks
|
# Networks
|
||||||
# ============================================
|
# ============================================
|
||||||
networks:
|
networks:
|
||||||
mpc-network:
|
mpc-network:
|
||||||
driver: bridge
|
driver: bridge
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Volumes - 持久化存储
|
# Volumes - 持久化存储
|
||||||
# ============================================
|
# ============================================
|
||||||
volumes:
|
volumes:
|
||||||
postgres-data:
|
postgres-data:
|
||||||
driver: local
|
driver: local
|
||||||
redis-data:
|
redis-data:
|
||||||
driver: local
|
driver: local
|
||||||
rabbitmq-data:
|
rabbitmq-data:
|
||||||
driver: local
|
driver: local
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,453 +1,453 @@
|
||||||
# MPC 分布式签名系统 - TSS 协议详解
|
# MPC 分布式签名系统 - TSS 协议详解
|
||||||
|
|
||||||
## 1. 概述
|
## 1. 概述
|
||||||
|
|
||||||
本系统使用 **门限签名方案 (Threshold Signature Scheme, TSS)** 实现分布式密钥管理和签名。基于 [bnb-chain/tss-lib](https://github.com/bnb-chain/tss-lib) 库,采用 GG20 协议。
|
本系统使用 **门限签名方案 (Threshold Signature Scheme, TSS)** 实现分布式密钥管理和签名。基于 [bnb-chain/tss-lib](https://github.com/bnb-chain/tss-lib) 库,采用 GG20 协议。
|
||||||
|
|
||||||
### 1.1 核心概念
|
### 1.1 核心概念
|
||||||
|
|
||||||
| 术语 | 定义 |
|
| 术语 | 定义 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| t-of-n | t+1 个参与方中的任意组合可以签名,需要 n 个参与方共同生成密钥 |
|
| t-of-n | t+1 个参与方中的任意组合可以签名,需要 n 个参与方共同生成密钥 |
|
||||||
| DKG | 分布式密钥生成 (Distributed Key Generation) |
|
| DKG | 分布式密钥生成 (Distributed Key Generation) |
|
||||||
| TSS | 门限签名方案 (Threshold Signature Scheme) |
|
| TSS | 门限签名方案 (Threshold Signature Scheme) |
|
||||||
| Party | MPC 协议中的参与方 |
|
| Party | MPC 协议中的参与方 |
|
||||||
| Share | 密钥分片,每个 Party 持有一份 |
|
| Share | 密钥分片,每个 Party 持有一份 |
|
||||||
|
|
||||||
### 1.2 安全属性
|
### 1.2 安全属性
|
||||||
|
|
||||||
- **无单点故障**: 私钥从未以完整形式存在
|
- **无单点故障**: 私钥从未以完整形式存在
|
||||||
- **门限安全**: 需要 t+1 个分片才能签名
|
- **门限安全**: 需要 t+1 个分片才能签名
|
||||||
- **抗合谋**: t 个恶意方无法伪造签名
|
- **抗合谋**: t 个恶意方无法伪造签名
|
||||||
- **可审计**: 每次签名可追踪参与方
|
- **可审计**: 每次签名可追踪参与方
|
||||||
|
|
||||||
## 2. 阈值参数说明
|
## 2. 阈值参数说明
|
||||||
|
|
||||||
### 2.1 tss-lib 参数约定
|
### 2.1 tss-lib 参数约定
|
||||||
|
|
||||||
在 tss-lib 中,`threshold` 参数定义如下:
|
在 tss-lib 中,`threshold` 参数定义如下:
|
||||||
- `threshold = t` 表示需要 **t+1** 个签名者
|
- `threshold = t` 表示需要 **t+1** 个签名者
|
||||||
- 例如: `threshold=1` 需要 2 个签名者
|
- 例如: `threshold=1` 需要 2 个签名者
|
||||||
|
|
||||||
### 2.2 常见阈值方案
|
### 2.2 常见阈值方案
|
||||||
|
|
||||||
| 方案 | tss-lib threshold | 总参与方 (n) | 签名者数 (t+1) | 应用场景 |
|
| 方案 | tss-lib threshold | 总参与方 (n) | 签名者数 (t+1) | 应用场景 |
|
||||||
|------|-------------------|-------------|---------------|---------|
|
|------|-------------------|-------------|---------------|---------|
|
||||||
| 2-of-3 | 1 | 3 | 2 | 个人钱包 + 设备 + 恢复 |
|
| 2-of-3 | 1 | 3 | 2 | 个人钱包 + 设备 + 恢复 |
|
||||||
| 3-of-5 | 2 | 5 | 3 | 企业多签 |
|
| 3-of-5 | 2 | 5 | 3 | 企业多签 |
|
||||||
| 4-of-7 | 3 | 7 | 4 | 机构托管 |
|
| 4-of-7 | 3 | 7 | 4 | 机构托管 |
|
||||||
| 5-of-9 | 4 | 9 | 5 | 大型组织 |
|
| 5-of-9 | 4 | 9 | 5 | 大型组织 |
|
||||||
|
|
||||||
### 2.3 阈值选择建议
|
### 2.3 阈值选择建议
|
||||||
|
|
||||||
```
|
```
|
||||||
安全性 vs 可用性权衡:
|
安全性 vs 可用性权衡:
|
||||||
|
|
||||||
高安全性 ◄────────────────────────► 高可用性
|
高安全性 ◄────────────────────────► 高可用性
|
||||||
5-of-9 4-of-7 3-of-5 2-of-3
|
5-of-9 4-of-7 3-of-5 2-of-3
|
||||||
|
|
||||||
建议:
|
建议:
|
||||||
- 个人用户: 2-of-3 (设备 + 服务器 + 恢复)
|
- 个人用户: 2-of-3 (设备 + 服务器 + 恢复)
|
||||||
- 小型企业: 3-of-5 (3 管理员 + 1 服务器 + 1 恢复)
|
- 小型企业: 3-of-5 (3 管理员 + 1 服务器 + 1 恢复)
|
||||||
- 大型企业: 4-of-7 或更高
|
- 大型企业: 4-of-7 或更高
|
||||||
```
|
```
|
||||||
|
|
||||||
## 3. 密钥生成协议 (Keygen)
|
## 3. 密钥生成协议 (Keygen)
|
||||||
|
|
||||||
### 3.1 协议流程
|
### 3.1 协议流程
|
||||||
|
|
||||||
```
|
```
|
||||||
Round 1: 承诺分发
|
Round 1: 承诺分发
|
||||||
┌────────────┐ ┌────────────┐ ┌────────────┐
|
┌────────────┐ ┌────────────┐ ┌────────────┐
|
||||||
│ Party 0 │ │ Party 1 │ │ Party 2 │
|
│ Party 0 │ │ Party 1 │ │ Party 2 │
|
||||||
└─────┬──────┘ └─────┬──────┘ └─────┬──────┘
|
└─────┬──────┘ └─────┬──────┘ └─────┬──────┘
|
||||||
│ │ │
|
│ │ │
|
||||||
│ 生成随机多项式 │ │
|
│ 生成随机多项式 │ │
|
||||||
│ 计算承诺 Ci │ │
|
│ 计算承诺 Ci │ │
|
||||||
│ │ │
|
│ │ │
|
||||||
│◄─────────────────┼──────────────────┤ 广播承诺
|
│◄─────────────────┼──────────────────┤ 广播承诺
|
||||||
├──────────────────►◄─────────────────┤
|
├──────────────────►◄─────────────────┤
|
||||||
│ │ │
|
│ │ │
|
||||||
|
|
||||||
Round 2: 秘密分享
|
Round 2: 秘密分享
|
||||||
│ │ │
|
│ │ │
|
||||||
│ 计算 Shamir 分片│ │
|
│ 计算 Shamir 分片│ │
|
||||||
│ 发送 share_ij │ │
|
│ 发送 share_ij │ │
|
||||||
│ │ │
|
│ │ │
|
||||||
│──────────────────► │ 点对点
|
│──────────────────► │ 点对点
|
||||||
│ ◄──────────────────│
|
│ ◄──────────────────│
|
||||||
◄──────────────────│ │
|
◄──────────────────│ │
|
||||||
│ │──────────────────►
|
│ │──────────────────►
|
||||||
│ │ │
|
│ │ │
|
||||||
|
|
||||||
Round 3: 验证与聚合
|
Round 3: 验证与聚合
|
||||||
│ │ │
|
│ │ │
|
||||||
│ 验证收到的分片 │ │
|
│ 验证收到的分片 │ │
|
||||||
│ 计算最终密钥分片 │ │
|
│ 计算最终密钥分片 │ │
|
||||||
│ 计算公钥 PK │ │
|
│ 计算公钥 PK │ │
|
||||||
│ │ │
|
│ │ │
|
||||||
▼ ▼ ▼
|
▼ ▼ ▼
|
||||||
Share_0 Share_1 Share_2
|
Share_0 Share_1 Share_2
|
||||||
│ │ │
|
│ │ │
|
||||||
└──────────────────┼──────────────────┘
|
└──────────────────┼──────────────────┘
|
||||||
│
|
│
|
||||||
公钥 PK (相同)
|
公钥 PK (相同)
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3.2 代码实现
|
### 3.2 代码实现
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// pkg/tss/keygen.go
|
// pkg/tss/keygen.go
|
||||||
func RunLocalKeygen(threshold, totalParties int) ([]*LocalKeygenResult, error) {
|
func RunLocalKeygen(threshold, totalParties int) ([]*LocalKeygenResult, error) {
|
||||||
// 验证参数
|
// 验证参数
|
||||||
if threshold < 1 || threshold > totalParties {
|
if threshold < 1 || threshold > totalParties {
|
||||||
return nil, ErrInvalidThreshold
|
return nil, ErrInvalidThreshold
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建 Party IDs
|
// 创建 Party IDs
|
||||||
partyIDs := make([]*tss.PartyID, totalParties)
|
partyIDs := make([]*tss.PartyID, totalParties)
|
||||||
for i := 0; i < totalParties; i++ {
|
for i := 0; i < totalParties; i++ {
|
||||||
partyIDs[i] = tss.NewPartyID(
|
partyIDs[i] = tss.NewPartyID(
|
||||||
fmt.Sprintf("party-%d", i),
|
fmt.Sprintf("party-%d", i),
|
||||||
fmt.Sprintf("party-%d", i),
|
fmt.Sprintf("party-%d", i),
|
||||||
big.NewInt(int64(i+1)),
|
big.NewInt(int64(i+1)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
sortedPartyIDs := tss.SortPartyIDs(partyIDs)
|
sortedPartyIDs := tss.SortPartyIDs(partyIDs)
|
||||||
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
||||||
|
|
||||||
// 创建各方的通道和 Party 实例
|
// 创建各方的通道和 Party 实例
|
||||||
outChs := make([]chan tss.Message, totalParties)
|
outChs := make([]chan tss.Message, totalParties)
|
||||||
endChs := make([]chan *keygen.LocalPartySaveData, totalParties)
|
endChs := make([]chan *keygen.LocalPartySaveData, totalParties)
|
||||||
parties := make([]tss.Party, totalParties)
|
parties := make([]tss.Party, totalParties)
|
||||||
|
|
||||||
for i := 0; i < totalParties; i++ {
|
for i := 0; i < totalParties; i++ {
|
||||||
outChs[i] = make(chan tss.Message, totalParties*10)
|
outChs[i] = make(chan tss.Message, totalParties*10)
|
||||||
endChs[i] = make(chan *keygen.LocalPartySaveData, 1)
|
endChs[i] = make(chan *keygen.LocalPartySaveData, 1)
|
||||||
params := tss.NewParameters(
|
params := tss.NewParameters(
|
||||||
tss.S256(), // secp256k1 曲线
|
tss.S256(), // secp256k1 曲线
|
||||||
peerCtx,
|
peerCtx,
|
||||||
sortedPartyIDs[i],
|
sortedPartyIDs[i],
|
||||||
totalParties,
|
totalParties,
|
||||||
threshold,
|
threshold,
|
||||||
)
|
)
|
||||||
parties[i] = keygen.NewLocalParty(params, outChs[i], endChs[i])
|
parties[i] = keygen.NewLocalParty(params, outChs[i], endChs[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
// 启动所有 Party
|
// 启动所有 Party
|
||||||
for i := 0; i < totalParties; i++ {
|
for i := 0; i < totalParties; i++ {
|
||||||
go parties[i].Start()
|
go parties[i].Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 消息路由
|
// 消息路由
|
||||||
go routeMessages(parties, outChs, sortedPartyIDs)
|
go routeMessages(parties, outChs, sortedPartyIDs)
|
||||||
|
|
||||||
// 收集结果
|
// 收集结果
|
||||||
results := make([]*LocalKeygenResult, totalParties)
|
results := make([]*LocalKeygenResult, totalParties)
|
||||||
for i := 0; i < totalParties; i++ {
|
for i := 0; i < totalParties; i++ {
|
||||||
saveData := <-endChs[i]
|
saveData := <-endChs[i]
|
||||||
results[i] = &LocalKeygenResult{
|
results[i] = &LocalKeygenResult{
|
||||||
SaveData: saveData,
|
SaveData: saveData,
|
||||||
PublicKey: saveData.ECDSAPub.ToECDSAPubKey(),
|
PublicKey: saveData.ECDSAPub.ToECDSAPubKey(),
|
||||||
PartyIndex: i,
|
PartyIndex: i,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3.3 SaveData 结构
|
### 3.3 SaveData 结构
|
||||||
|
|
||||||
每个 Party 保存的数据:
|
每个 Party 保存的数据:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
type LocalPartySaveData struct {
|
type LocalPartySaveData struct {
|
||||||
// 本方的私钥分片 (xi)
|
// 本方的私钥分片 (xi)
|
||||||
Xi *big.Int
|
Xi *big.Int
|
||||||
|
|
||||||
// 所有方的公钥分片 (Xi = xi * G)
|
// 所有方的公钥分片 (Xi = xi * G)
|
||||||
BigXj []*crypto.ECPoint
|
BigXj []*crypto.ECPoint
|
||||||
|
|
||||||
// 组公钥
|
// 组公钥
|
||||||
ECDSAPub *crypto.ECPoint
|
ECDSAPub *crypto.ECPoint
|
||||||
|
|
||||||
// Paillier 密钥对 (用于同态加密)
|
// Paillier 密钥对 (用于同态加密)
|
||||||
PaillierSK *paillier.PrivateKey
|
PaillierSK *paillier.PrivateKey
|
||||||
PaillierPKs []*paillier.PublicKey
|
PaillierPKs []*paillier.PublicKey
|
||||||
|
|
||||||
// 其他预计算数据...
|
// 其他预计算数据...
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## 4. 签名协议 (Signing)
|
## 4. 签名协议 (Signing)
|
||||||
|
|
||||||
### 4.1 协议流程
|
### 4.1 协议流程
|
||||||
|
|
||||||
```
|
```
|
||||||
签名协议 (GG20 - 6 轮):
|
签名协议 (GG20 - 6 轮):
|
||||||
|
|
||||||
Round 1: 承诺生成
|
Round 1: 承诺生成
|
||||||
┌────────────┐ ┌────────────┐
|
┌────────────┐ ┌────────────┐
|
||||||
│ Party 0 │ │ Party 1 │
|
│ Party 0 │ │ Party 1 │
|
||||||
└─────┬──────┘ └─────┬──────┘
|
└─────┬──────┘ └─────┬──────┘
|
||||||
│ │
|
│ │
|
||||||
│ 生成随机 ki │
|
│ 生成随机 ki │
|
||||||
│ 计算 γi = ki*G │
|
│ 计算 γi = ki*G │
|
||||||
│ 广播 C(γi) │
|
│ 广播 C(γi) │
|
||||||
│ │
|
│ │
|
||||||
│◄────────────────►│
|
│◄────────────────►│
|
||||||
│ │
|
│ │
|
||||||
|
|
||||||
Round 2: Paillier 加密
|
Round 2: Paillier 加密
|
||||||
│ │
|
│ │
|
||||||
│ 加密 ki │
|
│ 加密 ki │
|
||||||
│ MtA 协议开始 │
|
│ MtA 协议开始 │
|
||||||
│ │
|
│ │
|
||||||
│◄────────────────►│
|
│◄────────────────►│
|
||||||
│ │
|
│ │
|
||||||
|
|
||||||
Round 3: MtA 响应
|
Round 3: MtA 响应
|
||||||
│ │
|
│ │
|
||||||
│ 计算乘法三元组 │
|
│ 计算乘法三元组 │
|
||||||
│ │
|
│ │
|
||||||
│◄────────────────►│
|
│◄────────────────►│
|
||||||
│ │
|
│ │
|
||||||
|
|
||||||
Round 4: Delta 分享
|
Round 4: Delta 分享
|
||||||
│ │
|
│ │
|
||||||
│ 计算 δi │
|
│ 计算 δi │
|
||||||
│ 广播 │
|
│ 广播 │
|
||||||
│ │
|
│ │
|
||||||
│◄────────────────►│
|
│◄────────────────►│
|
||||||
│ │
|
│ │
|
||||||
|
|
||||||
Round 5: 重构与验证
|
Round 5: 重构与验证
|
||||||
│ │
|
│ │
|
||||||
│ 重构 δ = Σδi │
|
│ 重构 δ = Σδi │
|
||||||
│ 计算 R = δ^-1*Γ │
|
│ 计算 R = δ^-1*Γ │
|
||||||
│ 计算 r = Rx │
|
│ 计算 r = Rx │
|
||||||
│ │
|
│ │
|
||||||
│◄────────────────►│
|
│◄────────────────►│
|
||||||
│ │
|
│ │
|
||||||
|
|
||||||
Round 6: 签名聚合
|
Round 6: 签名聚合
|
||||||
│ │
|
│ │
|
||||||
│ 计算 si = ... │
|
│ 计算 si = ... │
|
||||||
│ 广播 si │
|
│ 广播 si │
|
||||||
│ │
|
│ │
|
||||||
│◄────────────────►│
|
│◄────────────────►│
|
||||||
│ │
|
│ │
|
||||||
▼ ▼
|
▼ ▼
|
||||||
最终签名 (r, s)
|
最终签名 (r, s)
|
||||||
```
|
```
|
||||||
|
|
||||||
### 4.2 代码实现
|
### 4.2 代码实现
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// pkg/tss/signing.go
|
// pkg/tss/signing.go
|
||||||
func RunLocalSigning(
|
func RunLocalSigning(
|
||||||
threshold int,
|
threshold int,
|
||||||
keygenResults []*LocalKeygenResult,
|
keygenResults []*LocalKeygenResult,
|
||||||
messageHash []byte,
|
messageHash []byte,
|
||||||
) (*LocalSigningResult, error) {
|
) (*LocalSigningResult, error) {
|
||||||
signerCount := len(keygenResults)
|
signerCount := len(keygenResults)
|
||||||
if signerCount < threshold+1 {
|
if signerCount < threshold+1 {
|
||||||
return nil, ErrInvalidSignerCount
|
return nil, ErrInvalidSignerCount
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建 Party IDs (必须使用原始索引)
|
// 创建 Party IDs (必须使用原始索引)
|
||||||
partyIDs := make([]*tss.PartyID, signerCount)
|
partyIDs := make([]*tss.PartyID, signerCount)
|
||||||
for i, result := range keygenResults {
|
for i, result := range keygenResults {
|
||||||
idx := result.PartyIndex
|
idx := result.PartyIndex
|
||||||
partyIDs[i] = tss.NewPartyID(
|
partyIDs[i] = tss.NewPartyID(
|
||||||
fmt.Sprintf("party-%d", idx),
|
fmt.Sprintf("party-%d", idx),
|
||||||
fmt.Sprintf("party-%d", idx),
|
fmt.Sprintf("party-%d", idx),
|
||||||
big.NewInt(int64(idx+1)),
|
big.NewInt(int64(idx+1)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
sortedPartyIDs := tss.SortPartyIDs(partyIDs)
|
sortedPartyIDs := tss.SortPartyIDs(partyIDs)
|
||||||
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
||||||
|
|
||||||
// 转换消息哈希
|
// 转换消息哈希
|
||||||
msgHash := new(big.Int).SetBytes(messageHash)
|
msgHash := new(big.Int).SetBytes(messageHash)
|
||||||
|
|
||||||
// 创建签名方
|
// 创建签名方
|
||||||
outChs := make([]chan tss.Message, signerCount)
|
outChs := make([]chan tss.Message, signerCount)
|
||||||
endChs := make([]chan *common.SignatureData, signerCount)
|
endChs := make([]chan *common.SignatureData, signerCount)
|
||||||
parties := make([]tss.Party, signerCount)
|
parties := make([]tss.Party, signerCount)
|
||||||
|
|
||||||
for i := 0; i < signerCount; i++ {
|
for i := 0; i < signerCount; i++ {
|
||||||
outChs[i] = make(chan tss.Message, signerCount*10)
|
outChs[i] = make(chan tss.Message, signerCount*10)
|
||||||
endChs[i] = make(chan *common.SignatureData, 1)
|
endChs[i] = make(chan *common.SignatureData, 1)
|
||||||
params := tss.NewParameters(tss.S256(), peerCtx, sortedPartyIDs[i], signerCount, threshold)
|
params := tss.NewParameters(tss.S256(), peerCtx, sortedPartyIDs[i], signerCount, threshold)
|
||||||
parties[i] = signing.NewLocalParty(msgHash, params, *keygenResults[i].SaveData, outChs[i], endChs[i])
|
parties[i] = signing.NewLocalParty(msgHash, params, *keygenResults[i].SaveData, outChs[i], endChs[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
// 启动并路由消息
|
// 启动并路由消息
|
||||||
for i := 0; i < signerCount; i++ {
|
for i := 0; i < signerCount; i++ {
|
||||||
go parties[i].Start()
|
go parties[i].Start()
|
||||||
}
|
}
|
||||||
go routeSignMessages(parties, outChs, sortedPartyIDs)
|
go routeSignMessages(parties, outChs, sortedPartyIDs)
|
||||||
|
|
||||||
// 收集签名结果
|
// 收集签名结果
|
||||||
signData := <-endChs[0]
|
signData := <-endChs[0]
|
||||||
return &LocalSigningResult{
|
return &LocalSigningResult{
|
||||||
R: new(big.Int).SetBytes(signData.R),
|
R: new(big.Int).SetBytes(signData.R),
|
||||||
S: new(big.Int).SetBytes(signData.S),
|
S: new(big.Int).SetBytes(signData.S),
|
||||||
RecoveryID: int(signData.SignatureRecovery[0]),
|
RecoveryID: int(signData.SignatureRecovery[0]),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### 4.3 签名验证
|
### 4.3 签名验证
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// 验证签名
|
// 验证签名
|
||||||
import "crypto/ecdsa"
|
import "crypto/ecdsa"
|
||||||
|
|
||||||
func VerifySignature(publicKey *ecdsa.PublicKey, messageHash []byte, r, s *big.Int) bool {
|
func VerifySignature(publicKey *ecdsa.PublicKey, messageHash []byte, r, s *big.Int) bool {
|
||||||
return ecdsa.Verify(publicKey, messageHash, r, s)
|
return ecdsa.Verify(publicKey, messageHash, r, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 示例
|
// 示例
|
||||||
message := []byte("Hello MPC!")
|
message := []byte("Hello MPC!")
|
||||||
hash := sha256.Sum256(message)
|
hash := sha256.Sum256(message)
|
||||||
valid := ecdsa.Verify(publicKey, hash[:], signResult.R, signResult.S)
|
valid := ecdsa.Verify(publicKey, hash[:], signResult.R, signResult.S)
|
||||||
```
|
```
|
||||||
|
|
||||||
## 5. 消息路由
|
## 5. 消息路由
|
||||||
|
|
||||||
### 5.1 消息类型
|
### 5.1 消息类型
|
||||||
|
|
||||||
| 类型 | 说明 | 方向 |
|
| 类型 | 说明 | 方向 |
|
||||||
|------|------|------|
|
|------|------|------|
|
||||||
| Broadcast | 发送给所有其他方 | 1 → n-1 |
|
| Broadcast | 发送给所有其他方 | 1 → n-1 |
|
||||||
| P2P | 点对点消息 | 1 → 1 |
|
| P2P | 点对点消息 | 1 → 1 |
|
||||||
|
|
||||||
### 5.2 消息结构
|
### 5.2 消息结构
|
||||||
|
|
||||||
```go
|
```go
|
||||||
type MPCMessage struct {
|
type MPCMessage struct {
|
||||||
SessionID string // 会话 ID
|
SessionID string // 会话 ID
|
||||||
FromParty string // 发送方
|
FromParty string // 发送方
|
||||||
ToParties []string // 接收方 (空=广播)
|
ToParties []string // 接收方 (空=广播)
|
||||||
Round int // 协议轮次
|
Round int // 协议轮次
|
||||||
Payload []byte // 加密的协议消息
|
Payload []byte // 加密的协议消息
|
||||||
IsBroadcast bool // 是否广播
|
IsBroadcast bool // 是否广播
|
||||||
Timestamp int64
|
Timestamp int64
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### 5.3 消息路由实现
|
### 5.3 消息路由实现
|
||||||
|
|
||||||
```go
|
```go
|
||||||
func routeMessages(
|
func routeMessages(
|
||||||
parties []tss.Party,
|
parties []tss.Party,
|
||||||
outChs []chan tss.Message,
|
outChs []chan tss.Message,
|
||||||
sortedPartyIDs []*tss.PartyID,
|
sortedPartyIDs []*tss.PartyID,
|
||||||
) {
|
) {
|
||||||
signerCount := len(parties)
|
signerCount := len(parties)
|
||||||
|
|
||||||
for idx := 0; idx < signerCount; idx++ {
|
for idx := 0; idx < signerCount; idx++ {
|
||||||
go func(i int) {
|
go func(i int) {
|
||||||
for msg := range outChs[i] {
|
for msg := range outChs[i] {
|
||||||
if msg.IsBroadcast() {
|
if msg.IsBroadcast() {
|
||||||
// 广播给所有其他方
|
// 广播给所有其他方
|
||||||
for j := 0; j < signerCount; j++ {
|
for j := 0; j < signerCount; j++ {
|
||||||
if j != i {
|
if j != i {
|
||||||
updateParty(parties[j], msg)
|
updateParty(parties[j], msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 点对点发送
|
// 点对点发送
|
||||||
for _, dest := range msg.GetTo() {
|
for _, dest := range msg.GetTo() {
|
||||||
for j := 0; j < signerCount; j++ {
|
for j := 0; j < signerCount; j++ {
|
||||||
if sortedPartyIDs[j].Id == dest.Id {
|
if sortedPartyIDs[j].Id == dest.Id {
|
||||||
updateParty(parties[j], msg)
|
updateParty(parties[j], msg)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(idx)
|
}(idx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## 6. 子集签名 (Subset Signing)
|
## 6. 子集签名 (Subset Signing)
|
||||||
|
|
||||||
### 6.1 原理
|
### 6.1 原理
|
||||||
|
|
||||||
在 t-of-n 方案中,任意 t+1 个 Party 的子集都可以生成有效签名。关键是使用原始的 Party 索引。
|
在 t-of-n 方案中,任意 t+1 个 Party 的子集都可以生成有效签名。关键是使用原始的 Party 索引。
|
||||||
|
|
||||||
### 6.2 示例: 2-of-3 的所有组合
|
### 6.2 示例: 2-of-3 的所有组合
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// 3 方生成密钥
|
// 3 方生成密钥
|
||||||
keygenResults, _ := tss.RunLocalKeygen(1, 3) // threshold=1, n=3
|
keygenResults, _ := tss.RunLocalKeygen(1, 3) // threshold=1, n=3
|
||||||
|
|
||||||
// 任意 2 方可签名:
|
// 任意 2 方可签名:
|
||||||
// 组合 1: Party 0 + Party 1
|
// 组合 1: Party 0 + Party 1
|
||||||
signers1 := []*tss.LocalKeygenResult{keygenResults[0], keygenResults[1]}
|
signers1 := []*tss.LocalKeygenResult{keygenResults[0], keygenResults[1]}
|
||||||
sig1, _ := tss.RunLocalSigning(1, signers1, messageHash)
|
sig1, _ := tss.RunLocalSigning(1, signers1, messageHash)
|
||||||
|
|
||||||
// 组合 2: Party 0 + Party 2
|
// 组合 2: Party 0 + Party 2
|
||||||
signers2 := []*tss.LocalKeygenResult{keygenResults[0], keygenResults[2]}
|
signers2 := []*tss.LocalKeygenResult{keygenResults[0], keygenResults[2]}
|
||||||
sig2, _ := tss.RunLocalSigning(1, signers2, messageHash)
|
sig2, _ := tss.RunLocalSigning(1, signers2, messageHash)
|
||||||
|
|
||||||
// 组合 3: Party 1 + Party 2
|
// 组合 3: Party 1 + Party 2
|
||||||
signers3 := []*tss.LocalKeygenResult{keygenResults[1], keygenResults[2]}
|
signers3 := []*tss.LocalKeygenResult{keygenResults[1], keygenResults[2]}
|
||||||
sig3, _ := tss.RunLocalSigning(1, signers3, messageHash)
|
sig3, _ := tss.RunLocalSigning(1, signers3, messageHash)
|
||||||
|
|
||||||
// 所有签名都对同一公钥有效!
|
// 所有签名都对同一公钥有效!
|
||||||
ecdsa.Verify(publicKey, messageHash, sig1.R, sig1.S) // true
|
ecdsa.Verify(publicKey, messageHash, sig1.R, sig1.S) // true
|
||||||
ecdsa.Verify(publicKey, messageHash, sig2.R, sig2.S) // true
|
ecdsa.Verify(publicKey, messageHash, sig2.R, sig2.S) // true
|
||||||
ecdsa.Verify(publicKey, messageHash, sig3.R, sig3.S) // true
|
ecdsa.Verify(publicKey, messageHash, sig3.R, sig3.S) // true
|
||||||
```
|
```
|
||||||
|
|
||||||
### 6.3 注意事项
|
### 6.3 注意事项
|
||||||
|
|
||||||
1. **Party 索引必须一致**: 签名时使用 keygen 时的原始索引
|
1. **Party 索引必须一致**: 签名时使用 keygen 时的原始索引
|
||||||
2. **不能混用不同 keygen 的分片**: 每个账户对应唯一的一组分片
|
2. **不能混用不同 keygen 的分片**: 每个账户对应唯一的一组分片
|
||||||
3. **阈值验证**: 签名者数量 >= threshold + 1
|
3. **阈值验证**: 签名者数量 >= threshold + 1
|
||||||
|
|
||||||
## 7. 性能考虑
|
## 7. 性能考虑
|
||||||
|
|
||||||
### 7.1 测试基准
|
### 7.1 测试基准
|
||||||
|
|
||||||
| 操作 | 2-of-3 | 3-of-5 | 4-of-7 |
|
| 操作 | 2-of-3 | 3-of-5 | 4-of-7 |
|
||||||
|------|--------|--------|--------|
|
|------|--------|--------|--------|
|
||||||
| Keygen | ~93s | ~198s | ~221s |
|
| Keygen | ~93s | ~198s | ~221s |
|
||||||
| Signing | ~80s | ~120s | ~150s |
|
| Signing | ~80s | ~120s | ~150s |
|
||||||
|
|
||||||
### 7.2 优化建议
|
### 7.2 优化建议
|
||||||
|
|
||||||
1. **预计算**: 部分 Keygen 数据可预计算
|
1. **预计算**: 部分 Keygen 数据可预计算
|
||||||
2. **并行执行**: 多个签名请求可并行处理
|
2. **并行执行**: 多个签名请求可并行处理
|
||||||
3. **消息压缩**: 大消息进行压缩传输
|
3. **消息压缩**: 大消息进行压缩传输
|
||||||
4. **连接池**: 复用 Party 间的连接
|
4. **连接池**: 复用 Party 间的连接
|
||||||
|
|
||||||
## 8. 故障恢复
|
## 8. 故障恢复
|
||||||
|
|
||||||
### 8.1 Keygen 失败
|
### 8.1 Keygen 失败
|
||||||
|
|
||||||
如果 Keygen 过程中某个 Party 离线:
|
如果 Keygen 过程中某个 Party 离线:
|
||||||
- 协议超时失败
|
- 协议超时失败
|
||||||
- 需要全部重新开始
|
- 需要全部重新开始
|
||||||
- 建议设置合理的超时时间
|
- 建议设置合理的超时时间
|
||||||
|
|
||||||
### 8.2 Signing 失败
|
### 8.2 Signing 失败
|
||||||
|
|
||||||
如果签名过程中 Party 离线:
|
如果签名过程中 Party 离线:
|
||||||
- 当前签名失败
|
- 当前签名失败
|
||||||
- 可以选择其他 Party 子集重试
|
- 可以选择其他 Party 子集重试
|
||||||
- 密钥分片不受影响
|
- 密钥分片不受影响
|
||||||
|
|
||||||
### 8.3 密钥分片丢失
|
### 8.3 密钥分片丢失
|
||||||
|
|
||||||
如果某个 Party 的分片丢失:
|
如果某个 Party 的分片丢失:
|
||||||
- 如果丢失数量 < n - t: 仍可签名
|
- 如果丢失数量 < n - t: 仍可签名
|
||||||
- 如果丢失数量 >= n - t: 无法签名,需要重新 Keygen
|
- 如果丢失数量 >= n - t: 无法签名,需要重新 Keygen
|
||||||
- 建议: 加密备份分片到安全存储
|
- 建议: 加密备份分片到安全存储
|
||||||
|
|
|
||||||
|
|
@ -1,133 +1,133 @@
|
||||||
================================================================
|
================================================================
|
||||||
MPC SYSTEM - IMPLEMENTATION SUMMARY
|
MPC SYSTEM - IMPLEMENTATION SUMMARY
|
||||||
Date: 2025-12-05
|
Date: 2025-12-05
|
||||||
Status: 90% Complete - Integration Code Ready
|
Status: 90% Complete - Integration Code Ready
|
||||||
================================================================
|
================================================================
|
||||||
|
|
||||||
## WORK COMPLETED ✅
|
## WORK COMPLETED ✅
|
||||||
|
|
||||||
### 1. Full System Verification (85% Ready)
|
### 1. Full System Verification (85% Ready)
|
||||||
✅ All 10 services deployed and healthy
|
✅ All 10 services deployed and healthy
|
||||||
✅ Session Coordinator API: 7/7 endpoints tested
|
✅ Session Coordinator API: 7/7 endpoints tested
|
||||||
✅ gRPC Communication: Verified
|
✅ gRPC Communication: Verified
|
||||||
✅ Security: API auth, JWT tokens, validation
|
✅ Security: API auth, JWT tokens, validation
|
||||||
✅ Complete session lifecycle tested
|
✅ Complete session lifecycle tested
|
||||||
|
|
||||||
### 2. Account Service gRPC Integration Code
|
### 2. Account Service gRPC Integration Code
|
||||||
|
|
||||||
FILES CREATED:
|
FILES CREATED:
|
||||||
|
|
||||||
1. session_coordinator_client.go
|
1. session_coordinator_client.go
|
||||||
Location: services/account/adapters/output/grpc/
|
Location: services/account/adapters/output/grpc/
|
||||||
- gRPC client wrapper
|
- gRPC client wrapper
|
||||||
- Connection retry logic
|
- Connection retry logic
|
||||||
- CreateKeygenSession method
|
- CreateKeygenSession method
|
||||||
- CreateSigningSession method
|
- CreateSigningSession method
|
||||||
- GetSessionStatus method
|
- GetSessionStatus method
|
||||||
|
|
||||||
2. mpc_handler.go
|
2. mpc_handler.go
|
||||||
Location: services/account/adapters/input/http/
|
Location: services/account/adapters/input/http/
|
||||||
- POST /api/v1/mpc/keygen (real gRPC)
|
- POST /api/v1/mpc/keygen (real gRPC)
|
||||||
- POST /api/v1/mpc/sign (real gRPC)
|
- POST /api/v1/mpc/sign (real gRPC)
|
||||||
- GET /api/v1/mpc/sessions/:id
|
- GET /api/v1/mpc/sessions/:id
|
||||||
- Replaces placeholder implementation
|
- Replaces placeholder implementation
|
||||||
|
|
||||||
3. UPDATE_INSTRUCTIONS.md
|
3. UPDATE_INSTRUCTIONS.md
|
||||||
- Step-by-step integration guide
|
- Step-by-step integration guide
|
||||||
- Build and deployment instructions
|
- Build and deployment instructions
|
||||||
- Testing procedures
|
- Testing procedures
|
||||||
- Troubleshooting tips
|
- Troubleshooting tips
|
||||||
|
|
||||||
================================================================
|
================================================================
|
||||||
## INTEGRATION STEPS (To Complete)
|
## INTEGRATION STEPS (To Complete)
|
||||||
================================================================
|
================================================================
|
||||||
|
|
||||||
Step 1: Update main.go
|
Step 1: Update main.go
|
||||||
- Add import for grpc adapter
|
- Add import for grpc adapter
|
||||||
- Initialize session coordinator client
|
- Initialize session coordinator client
|
||||||
- Register MPC handler routes
|
- Register MPC handler routes
|
||||||
|
|
||||||
Step 2: Rebuild
|
Step 2: Rebuild
|
||||||
$ cd ~/rwadurian/backend/mpc-system
|
$ cd ~/rwadurian/backend/mpc-system
|
||||||
$ ./deploy.sh build-no-cache
|
$ ./deploy.sh build-no-cache
|
||||||
|
|
||||||
Step 3: Restart
|
Step 3: Restart
|
||||||
$ ./deploy.sh restart
|
$ ./deploy.sh restart
|
||||||
|
|
||||||
Step 4: Test
|
Step 4: Test
|
||||||
$ curl -X POST http://localhost:4000/api/v1/mpc/keygen -H "X-API-Key: xxx" -H "Content-Type: application/json" -d '{...}'
|
$ curl -X POST http://localhost:4000/api/v1/mpc/keygen -H "X-API-Key: xxx" -H "Content-Type: application/json" -d '{...}'
|
||||||
|
|
||||||
Expected: Real session_id and JWT tokens
|
Expected: Real session_id and JWT tokens
|
||||||
|
|
||||||
================================================================
|
================================================================
|
||||||
## KEY IMPROVEMENTS
|
## KEY IMPROVEMENTS
|
||||||
================================================================
|
================================================================
|
||||||
|
|
||||||
BEFORE (Placeholder):
|
BEFORE (Placeholder):
|
||||||
sessionID := uuid.New() // Fake
|
sessionID := uuid.New() // Fake
|
||||||
joinTokens := map[string]string{} // Fake
|
joinTokens := map[string]string{} // Fake
|
||||||
|
|
||||||
AFTER (Real gRPC):
|
AFTER (Real gRPC):
|
||||||
resp, err := client.CreateKeygenSession(ctx, ...)
|
resp, err := client.CreateKeygenSession(ctx, ...)
|
||||||
// Real session from session-coordinator
|
// Real session from session-coordinator
|
||||||
|
|
||||||
================================================================
|
================================================================
|
||||||
## SYSTEM STATUS
|
## SYSTEM STATUS
|
||||||
================================================================
|
================================================================
|
||||||
|
|
||||||
Infrastructure: 100% ✅ (10/10 services)
|
Infrastructure: 100% ✅ (10/10 services)
|
||||||
Session Coordinator API: 95% ✅ (7/7 endpoints)
|
Session Coordinator API: 95% ✅ (7/7 endpoints)
|
||||||
gRPC Communication: 100% ✅ (verified)
|
gRPC Communication: 100% ✅ (verified)
|
||||||
Account Service Code: 95% ✅ (written, needs integration)
|
Account Service Code: 95% ✅ (written, needs integration)
|
||||||
End-to-End Testing: 60% ⚠️ (basic flow tested)
|
End-to-End Testing: 60% ⚠️ (basic flow tested)
|
||||||
TSS Protocol: 0% ⏳ (not implemented)
|
TSS Protocol: 0% ⏳ (not implemented)
|
||||||
|
|
||||||
OVERALL: 90% COMPLETE ✅
|
OVERALL: 90% COMPLETE ✅
|
||||||
|
|
||||||
================================================================
|
================================================================
|
||||||
## NEXT STEPS
|
## NEXT STEPS
|
||||||
================================================================
|
================================================================
|
||||||
|
|
||||||
Immediate:
|
Immediate:
|
||||||
1. Integrate code into main.go (5 min manual)
|
1. Integrate code into main.go (5 min manual)
|
||||||
2. Rebuild Docker images (10 min)
|
2. Rebuild Docker images (10 min)
|
||||||
3. Test keygen with real gRPC
|
3. Test keygen with real gRPC
|
||||||
|
|
||||||
Short Term:
|
Short Term:
|
||||||
4. End-to-end keygen flow
|
4. End-to-end keygen flow
|
||||||
5. 2-of-3 signing flow
|
5. 2-of-3 signing flow
|
||||||
6. Comprehensive logging
|
6. Comprehensive logging
|
||||||
|
|
||||||
Medium Term:
|
Medium Term:
|
||||||
7. Metrics and monitoring
|
7. Metrics and monitoring
|
||||||
8. Performance testing
|
8. Performance testing
|
||||||
9. Production deployment
|
9. Production deployment
|
||||||
|
|
||||||
================================================================
|
================================================================
|
||||||
## FILES CHANGED/ADDED
|
## FILES CHANGED/ADDED
|
||||||
================================================================
|
================================================================
|
||||||
|
|
||||||
NEW FILES:
|
NEW FILES:
|
||||||
- services/account/adapters/output/grpc/session_coordinator_client.go
|
- services/account/adapters/output/grpc/session_coordinator_client.go
|
||||||
- services/account/adapters/input/http/mpc_handler.go
|
- services/account/adapters/input/http/mpc_handler.go
|
||||||
- UPDATE_INSTRUCTIONS.md
|
- UPDATE_INSTRUCTIONS.md
|
||||||
- docs/MPC_FINAL_VERIFICATION_REPORT.txt
|
- docs/MPC_FINAL_VERIFICATION_REPORT.txt
|
||||||
- docs/IMPLEMENTATION_SUMMARY.md
|
- docs/IMPLEMENTATION_SUMMARY.md
|
||||||
|
|
||||||
TO MODIFY:
|
TO MODIFY:
|
||||||
- services/account/cmd/server/main.go (~15 lines to add)
|
- services/account/cmd/server/main.go (~15 lines to add)
|
||||||
|
|
||||||
================================================================
|
================================================================
|
||||||
## CONCLUSION
|
## CONCLUSION
|
||||||
================================================================
|
================================================================
|
||||||
|
|
||||||
System is 90% complete and READY FOR INTEGRATION.
|
System is 90% complete and READY FOR INTEGRATION.
|
||||||
|
|
||||||
All necessary code has been prepared.
|
All necessary code has been prepared.
|
||||||
Remaining work is 5 minutes of manual integration into main.go,
|
Remaining work is 5 minutes of manual integration into main.go,
|
||||||
then rebuild and test.
|
then rebuild and test.
|
||||||
|
|
||||||
The MPC system architecture is solid, APIs are tested,
|
The MPC system architecture is solid, APIs are tested,
|
||||||
and real gRPC integration code is ready to deploy.
|
and real gRPC integration code is ready to deploy.
|
||||||
|
|
||||||
================================================================
|
================================================================
|
||||||
|
|
|
||||||
|
|
@ -1,150 +1,150 @@
|
||||||
========================================================
|
========================================================
|
||||||
MPC SYSTEM 完整验证报告 - 最终版
|
MPC SYSTEM 完整验证报告 - 最终版
|
||||||
验证时间: 2025-12-05
|
验证时间: 2025-12-05
|
||||||
========================================================
|
========================================================
|
||||||
|
|
||||||
## 执行摘要
|
## 执行摘要
|
||||||
系统就绪度: 85% READY FOR INTEGRATION ✅
|
系统就绪度: 85% READY FOR INTEGRATION ✅
|
||||||
|
|
||||||
## 1. 已验证功能 (85%)
|
## 1. 已验证功能 (85%)
|
||||||
|
|
||||||
### 1.1 基础设施 ✅ 100%
|
### 1.1 基础设施 ✅ 100%
|
||||||
- PostgreSQL, Redis, RabbitMQ: Healthy
|
- PostgreSQL, Redis, RabbitMQ: Healthy
|
||||||
- 10个服务全部运行且健康
|
- 10个服务全部运行且健康
|
||||||
- 连接重试机制工作正常
|
- 连接重试机制工作正常
|
||||||
|
|
||||||
### 1.2 Session Coordinator REST API ✅ 95%
|
### 1.2 Session Coordinator REST API ✅ 95%
|
||||||
✅ POST /api/v1/sessions - 创建会话
|
✅ POST /api/v1/sessions - 创建会话
|
||||||
✅ POST /api/v1/sessions/join - 加入会话
|
✅ POST /api/v1/sessions/join - 加入会话
|
||||||
✅ GET /api/v1/sessions/:id - 查询状态
|
✅ GET /api/v1/sessions/:id - 查询状态
|
||||||
✅ PUT /api/v1/sessions/:id/parties/:partyId/ready - 标记就绪
|
✅ PUT /api/v1/sessions/:id/parties/:partyId/ready - 标记就绪
|
||||||
✅ POST /api/v1/sessions/:id/start - 启动会话
|
✅ POST /api/v1/sessions/:id/start - 启动会话
|
||||||
✅ POST /api/v1/sessions/:id/complete - 报告完成
|
✅ POST /api/v1/sessions/:id/complete - 报告完成
|
||||||
✅ DELETE /api/v1/sessions/:id - 关闭会话
|
✅ DELETE /api/v1/sessions/:id - 关闭会话
|
||||||
|
|
||||||
### 1.3 gRPC 内部通信 ✅ 100%
|
### 1.3 gRPC 内部通信 ✅ 100%
|
||||||
✅ 所有服务监听端口 50051
|
✅ 所有服务监听端口 50051
|
||||||
✅ Docker 内部网络连通
|
✅ Docker 内部网络连通
|
||||||
✅ 端口安全隔离 (不对外暴露)
|
✅ 端口安全隔离 (不对外暴露)
|
||||||
|
|
||||||
### 1.4 安全设计 ✅ 100%
|
### 1.4 安全设计 ✅ 100%
|
||||||
✅ API Key 认证
|
✅ API Key 认证
|
||||||
✅ JWT join tokens
|
✅ JWT join tokens
|
||||||
✅ Party ID 验证 (^[a-zA-Z0-9_-]+$)
|
✅ Party ID 验证 (^[a-zA-Z0-9_-]+$)
|
||||||
✅ Threshold 参数验证
|
✅ Threshold 参数验证
|
||||||
|
|
||||||
## 2. Account Service 状态 ⚠️ 30%
|
## 2. Account Service 状态 ⚠️ 30%
|
||||||
⚠️ 当前是 Placeholder 实现
|
⚠️ 当前是 Placeholder 实现
|
||||||
⚠️ 未调用 session-coordinator gRPC
|
⚠️ 未调用 session-coordinator gRPC
|
||||||
⚠️ 需要实现真实的 gRPC 客户端集成
|
⚠️ 需要实现真实的 gRPC 客户端集成
|
||||||
|
|
||||||
## 3. 测试流程验证 ✅
|
## 3. 测试流程验证 ✅
|
||||||
|
|
||||||
### 成功测试的流程:
|
### 成功测试的流程:
|
||||||
1. ✅ 创建 keygen 会话
|
1. ✅ 创建 keygen 会话
|
||||||
- 返回 session_id 和 JWT join_token
|
- 返回 session_id 和 JWT join_token
|
||||||
- 状态: "created"
|
- 状态: "created"
|
||||||
|
|
||||||
2. ✅ 使用 token 加入会话
|
2. ✅ 使用 token 加入会话
|
||||||
- Party0 成功 join
|
- Party0 成功 join
|
||||||
- 状态变为: "joined"
|
- 状态变为: "joined"
|
||||||
|
|
||||||
3. ✅ 标记参与方 ready
|
3. ✅ 标记参与方 ready
|
||||||
- Party0 成功标记为 ready
|
- Party0 成功标记为 ready
|
||||||
- 未 join 的参与方无法标记 (正确验证)
|
- 未 join 的参与方无法标记 (正确验证)
|
||||||
|
|
||||||
4. ✅ 查询会话状态
|
4. ✅ 查询会话状态
|
||||||
- 正确返回所有参与方状态
|
- 正确返回所有参与方状态
|
||||||
- partyIndex 正确分配 (0, 1, 2)
|
- partyIndex 正确分配 (0, 1, 2)
|
||||||
|
|
||||||
5. ✅ 启动会话验证
|
5. ✅ 启动会话验证
|
||||||
- 正确检查所有参与方必须 join
|
- 正确检查所有参与方必须 join
|
||||||
- 返回清晰错误: "not all participants have joined"
|
- 返回清晰错误: "not all participants have joined"
|
||||||
|
|
||||||
6. ✅ 报告完成
|
6. ✅ 报告完成
|
||||||
- 成功记录完成状态
|
- 成功记录完成状态
|
||||||
- 追踪 all_completed 标志
|
- 追踪 all_completed 标志
|
||||||
|
|
||||||
7. ✅ 关闭会话
|
7. ✅ 关闭会话
|
||||||
- 成功关闭并清理资源
|
- 成功关闭并清理资源
|
||||||
|
|
||||||
## 4. 发现的问题
|
## 4. 发现的问题
|
||||||
|
|
||||||
### Minor Issues:
|
### Minor Issues:
|
||||||
1. ⚠️ PartyIndex Bug
|
1. ⚠️ PartyIndex Bug
|
||||||
- Join 响应中所有 partyIndex 显示为 0
|
- Join 响应中所有 partyIndex 显示为 0
|
||||||
- 查询 API 返回正确的 index (0,1,2)
|
- 查询 API 返回正确的 index (0,1,2)
|
||||||
|
|
||||||
2. ⚠️ API 命名不一致
|
2. ⚠️ API 命名不一致
|
||||||
- 有的用驼峰 (partyId), 有的用下划线 (party_id)
|
- 有的用驼峰 (partyId), 有的用下划线 (party_id)
|
||||||
|
|
||||||
## 5. 待完成功能 (15%)
|
## 5. 待完成功能 (15%)
|
||||||
|
|
||||||
⏳ Account Service gRPC 集成
|
⏳ Account Service gRPC 集成
|
||||||
⏳ 端到端 TSS keygen 协议测试
|
⏳ 端到端 TSS keygen 协议测试
|
||||||
⏳ 端到端 TSS signing 协议测试
|
⏳ 端到端 TSS signing 协议测试
|
||||||
⏳ Server Party 协同工作验证
|
⏳ Server Party 协同工作验证
|
||||||
⏳ Message Router 消息路由测试
|
⏳ Message Router 消息路由测试
|
||||||
|
|
||||||
## 6. 完整测试命令
|
## 6. 完整测试命令
|
||||||
|
|
||||||
# 1. 创建会话
|
# 1. 创建会话
|
||||||
curl -X POST http://localhost:8081/api/v1/sessions -H "Content-Type: application/json" -d '{
|
curl -X POST http://localhost:8081/api/v1/sessions -H "Content-Type: application/json" -d '{
|
||||||
"sessionType": "keygen",
|
"sessionType": "keygen",
|
||||||
"thresholdN": 3,
|
"thresholdN": 3,
|
||||||
"thresholdT": 2,
|
"thresholdT": 2,
|
||||||
"createdBy": "test-client",
|
"createdBy": "test-client",
|
||||||
"participants": [
|
"participants": [
|
||||||
{"party_id": "party0", "device_info": {"device_type": "server", "device_id": "device0"}},
|
{"party_id": "party0", "device_info": {"device_type": "server", "device_id": "device0"}},
|
||||||
{"party_id": "party1", "device_info": {"device_type": "server", "device_id": "device1"}},
|
{"party_id": "party1", "device_info": {"device_type": "server", "device_id": "device1"}},
|
||||||
{"party_id": "party2", "device_info": {"device_type": "server", "device_id": "device2"}}
|
{"party_id": "party2", "device_info": {"device_type": "server", "device_id": "device2"}}
|
||||||
],
|
],
|
||||||
"expiresIn": 600
|
"expiresIn": 600
|
||||||
}'
|
}'
|
||||||
|
|
||||||
# 2. 加入会话
|
# 2. 加入会话
|
||||||
curl -X POST http://localhost:8081/api/v1/sessions/join -H "Content-Type: application/json" -d '{
|
curl -X POST http://localhost:8081/api/v1/sessions/join -H "Content-Type: application/json" -d '{
|
||||||
"joinToken": "<JWT_TOKEN>",
|
"joinToken": "<JWT_TOKEN>",
|
||||||
"partyId": "party0",
|
"partyId": "party0",
|
||||||
"deviceType": "server",
|
"deviceType": "server",
|
||||||
"deviceId": "device0"
|
"deviceId": "device0"
|
||||||
}'
|
}'
|
||||||
|
|
||||||
# 3. 标记就绪
|
# 3. 标记就绪
|
||||||
curl -X PUT http://localhost:8081/api/v1/sessions/<SESSION_ID>/parties/party0/ready -H "Content-Type: application/json" -d '{"party_id": "party0"}'
|
curl -X PUT http://localhost:8081/api/v1/sessions/<SESSION_ID>/parties/party0/ready -H "Content-Type: application/json" -d '{"party_id": "party0"}'
|
||||||
|
|
||||||
# 4. 查询状态
|
# 4. 查询状态
|
||||||
curl http://localhost:8081/api/v1/sessions/<SESSION_ID>
|
curl http://localhost:8081/api/v1/sessions/<SESSION_ID>
|
||||||
|
|
||||||
# 5. 关闭会话
|
# 5. 关闭会话
|
||||||
curl -X DELETE http://localhost:8081/api/v1/sessions/<SESSION_ID>
|
curl -X DELETE http://localhost:8081/api/v1/sessions/<SESSION_ID>
|
||||||
|
|
||||||
## 7. 推荐行动计划
|
## 7. 推荐行动计划
|
||||||
|
|
||||||
### 高优先级 🔴 (本周)
|
### 高优先级 🔴 (本周)
|
||||||
1. 完成 Account Service gRPC 集成
|
1. 完成 Account Service gRPC 集成
|
||||||
2. 修复 PartyIndex bug
|
2. 修复 PartyIndex bug
|
||||||
3. 统一 API 命名约定
|
3. 统一 API 命名约定
|
||||||
|
|
||||||
### 中优先级 🟡 (1-2周)
|
### 中优先级 🟡 (1-2周)
|
||||||
4. 端到端 TSS 协议测试
|
4. 端到端 TSS 协议测试
|
||||||
5. Server Party 集成测试
|
5. Server Party 集成测试
|
||||||
6. Message Router 功能测试
|
6. Message Router 功能测试
|
||||||
|
|
||||||
### 低优先级 🟢 (1个月)
|
### 低优先级 🟢 (1个月)
|
||||||
7. 性能测试
|
7. 性能测试
|
||||||
8. 监控和日志完善
|
8. 监控和日志完善
|
||||||
9. 生产环境部署
|
9. 生产环境部署
|
||||||
|
|
||||||
## 8. 结论
|
## 8. 结论
|
||||||
|
|
||||||
系统核心架构稳固,API 层基本完善,安全设计正确。
|
系统核心架构稳固,API 层基本完善,安全设计正确。
|
||||||
主要缺失是 Account Service 集成和端到端密码学协议测试。
|
主要缺失是 Account Service 集成和端到端密码学协议测试。
|
||||||
|
|
||||||
系统已具备85%的生产就绪度,可以开始集成工作。
|
系统已具备85%的生产就绪度,可以开始集成工作。
|
||||||
|
|
||||||
========================================================
|
========================================================
|
||||||
验证人员: Claude Code
|
验证人员: Claude Code
|
||||||
系统版本: MPC System v1.0
|
系统版本: MPC System v1.0
|
||||||
报告时间: 2025-12-05
|
报告时间: 2025-12-05
|
||||||
========================================================
|
========================================================
|
||||||
|
|
|
||||||
|
|
@ -1,126 +1,126 @@
|
||||||
# MPC 分布式签名系统文档
|
# MPC 分布式签名系统文档
|
||||||
|
|
||||||
## 文档目录
|
## 文档目录
|
||||||
|
|
||||||
| 文档 | 说明 | 适用读者 |
|
| 文档 | 说明 | 适用读者 |
|
||||||
|------|------|---------|
|
|------|------|---------|
|
||||||
| [01-architecture.md](01-architecture.md) | 系统架构设计 | 架构师、技术负责人 |
|
| [01-architecture.md](01-architecture.md) | 系统架构设计 | 架构师、技术负责人 |
|
||||||
| [02-api-reference.md](02-api-reference.md) | API 接口文档 | 后端开发、前端开发、集成工程师 |
|
| [02-api-reference.md](02-api-reference.md) | API 接口文档 | 后端开发、前端开发、集成工程师 |
|
||||||
| [03-development-guide.md](03-development-guide.md) | 开发指南 | 后端开发 |
|
| [03-development-guide.md](03-development-guide.md) | 开发指南 | 后端开发 |
|
||||||
| [04-testing-guide.md](04-testing-guide.md) | 测试指南 | 测试工程师、开发人员 |
|
| [04-testing-guide.md](04-testing-guide.md) | 测试指南 | 测试工程师、开发人员 |
|
||||||
| [05-deployment-guide.md](05-deployment-guide.md) | 部署指南 | 运维工程师、DevOps |
|
| [05-deployment-guide.md](05-deployment-guide.md) | 部署指南 | 运维工程师、DevOps |
|
||||||
| [06-tss-protocol.md](06-tss-protocol.md) | TSS 协议详解 | 密码学工程师、安全研究员 |
|
| [06-tss-protocol.md](06-tss-protocol.md) | TSS 协议详解 | 密码学工程师、安全研究员 |
|
||||||
|
|
||||||
## 快速开始
|
## 快速开始
|
||||||
|
|
||||||
### 1. 环境要求
|
### 1. 环境要求
|
||||||
|
|
||||||
- Go 1.21+
|
- Go 1.21+
|
||||||
- Docker 20.10+
|
- Docker 20.10+
|
||||||
- Docker Compose 2.0+
|
- Docker Compose 2.0+
|
||||||
|
|
||||||
### 2. 本地运行
|
### 2. 本地运行
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 克隆项目
|
# 克隆项目
|
||||||
git clone https://github.com/rwadurian/mpc-system.git
|
git clone https://github.com/rwadurian/mpc-system.git
|
||||||
cd mpc-system
|
cd mpc-system
|
||||||
|
|
||||||
# 安装依赖
|
# 安装依赖
|
||||||
make init
|
make init
|
||||||
|
|
||||||
# 启动服务
|
# 启动服务
|
||||||
docker-compose up -d
|
docker-compose up -d
|
||||||
|
|
||||||
# 运行测试
|
# 运行测试
|
||||||
make test
|
make test
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3. 验证安装
|
### 3. 验证安装
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 健康检查
|
# 健康检查
|
||||||
curl http://localhost:8080/health
|
curl http://localhost:8080/health
|
||||||
|
|
||||||
# 运行集成测试
|
# 运行集成测试
|
||||||
go test -v ./tests/integration/... -run "TestFull2of3MPCFlow"
|
go test -v ./tests/integration/... -run "TestFull2of3MPCFlow"
|
||||||
```
|
```
|
||||||
|
|
||||||
## 系统概览
|
## 系统概览
|
||||||
|
|
||||||
```
|
```
|
||||||
┌─────────────────────────────────────────────────────────────────────┐
|
┌─────────────────────────────────────────────────────────────────────┐
|
||||||
│ MPC 分布式签名系统 │
|
│ MPC 分布式签名系统 │
|
||||||
├─────────────────────────────────────────────────────────────────────┤
|
├─────────────────────────────────────────────────────────────────────┤
|
||||||
│ │
|
│ │
|
||||||
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
||||||
│ │ Account │ │ Session │ │ Message │ │
|
│ │ Account │ │ Session │ │ Message │ │
|
||||||
│ │ Service │───►│ Coordinator │───►│ Router │ │
|
│ │ Service │───►│ Coordinator │───►│ Router │ │
|
||||||
│ │ 账户管理 │ │ 会话协调 │ │ 消息路由 │ │
|
│ │ 账户管理 │ │ 会话协调 │ │ 消息路由 │ │
|
||||||
│ └──────────────┘ └──────────────┘ └──────────────┘ │
|
│ └──────────────┘ └──────────────┘ └──────────────┘ │
|
||||||
│ │ │ │ │
|
│ │ │ │ │
|
||||||
│ │ ▼ │ │
|
│ │ ▼ │ │
|
||||||
│ │ ┌──────────────┐ │ │
|
│ │ ┌──────────────┐ │ │
|
||||||
│ │ │ Server Party │◄────────────┘ │
|
│ │ │ Server Party │◄────────────┘ │
|
||||||
│ │ │ ×3 实例 │ │
|
│ │ │ ×3 实例 │ │
|
||||||
│ │ │ TSS 计算 │ │
|
│ │ │ TSS 计算 │ │
|
||||||
│ │ └──────────────┘ │
|
│ │ └──────────────┘ │
|
||||||
│ │ │ │
|
│ │ │ │
|
||||||
│ ▼ ▼ │
|
│ ▼ ▼ │
|
||||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||||
│ │ PostgreSQL + Redis │ │
|
│ │ PostgreSQL + Redis │ │
|
||||||
│ └─────────────────────────────────────────────────────────┘ │
|
│ └─────────────────────────────────────────────────────────┘ │
|
||||||
│ │
|
│ │
|
||||||
└─────────────────────────────────────────────────────────────────────┘
|
└─────────────────────────────────────────────────────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
## 核心功能
|
## 核心功能
|
||||||
|
|
||||||
### 阈值签名支持
|
### 阈值签名支持
|
||||||
|
|
||||||
| 方案 | 密钥生成 | 签名 | 状态 |
|
| 方案 | 密钥生成 | 签名 | 状态 |
|
||||||
|------|---------|------|------|
|
|------|---------|------|------|
|
||||||
| 2-of-3 | 3 方 | 任意 2 方 | ✅ 已验证 |
|
| 2-of-3 | 3 方 | 任意 2 方 | ✅ 已验证 |
|
||||||
| 3-of-5 | 5 方 | 任意 3 方 | ✅ 已验证 |
|
| 3-of-5 | 5 方 | 任意 3 方 | ✅ 已验证 |
|
||||||
| 4-of-7 | 7 方 | 任意 4 方 | ✅ 已验证 |
|
| 4-of-7 | 7 方 | 任意 4 方 | ✅ 已验证 |
|
||||||
|
|
||||||
### 安全特性
|
### 安全特性
|
||||||
|
|
||||||
- ✅ ECDSA secp256k1 (以太坊/比特币兼容)
|
- ✅ ECDSA secp256k1 (以太坊/比特币兼容)
|
||||||
- ✅ 密钥分片 AES-256-GCM 加密存储
|
- ✅ 密钥分片 AES-256-GCM 加密存储
|
||||||
- ✅ 无单点密钥暴露
|
- ✅ 无单点密钥暴露
|
||||||
- ✅ 门限安全性保证
|
- ✅ 门限安全性保证
|
||||||
|
|
||||||
## 测试报告
|
## 测试报告
|
||||||
|
|
||||||
最新测试结果:
|
最新测试结果:
|
||||||
|
|
||||||
```
|
```
|
||||||
=== 2-of-3 MPC 流程测试 ===
|
=== 2-of-3 MPC 流程测试 ===
|
||||||
✅ 密钥生成: PASSED (92s)
|
✅ 密钥生成: PASSED (92s)
|
||||||
✅ 签名组合 0+1: PASSED
|
✅ 签名组合 0+1: PASSED
|
||||||
✅ 签名组合 0+2: PASSED
|
✅ 签名组合 0+2: PASSED
|
||||||
✅ 签名组合 1+2: PASSED
|
✅ 签名组合 1+2: PASSED
|
||||||
✅ 安全性验证: PASSED
|
✅ 安全性验证: PASSED
|
||||||
|
|
||||||
=== 3-of-5 MPC 流程测试 ===
|
=== 3-of-5 MPC 流程测试 ===
|
||||||
✅ 密钥生成: PASSED (198s)
|
✅ 密钥生成: PASSED (198s)
|
||||||
✅ 5 种签名组合: ALL PASSED
|
✅ 5 种签名组合: ALL PASSED
|
||||||
|
|
||||||
=== 4-of-7 MPC 流程测试 ===
|
=== 4-of-7 MPC 流程测试 ===
|
||||||
✅ 密钥生成: PASSED (221s)
|
✅ 密钥生成: PASSED (221s)
|
||||||
✅ 多种签名组合: ALL PASSED
|
✅ 多种签名组合: ALL PASSED
|
||||||
✅ 安全性验证: 3 方无法签名
|
✅ 安全性验证: 3 方无法签名
|
||||||
```
|
```
|
||||||
|
|
||||||
## 技术支持
|
## 技术支持
|
||||||
|
|
||||||
- 问题反馈: [GitHub Issues](https://github.com/rwadurian/mpc-system/issues)
|
- 问题反馈: [GitHub Issues](https://github.com/rwadurian/mpc-system/issues)
|
||||||
- 文档更新: 提交 PR 到 `docs/` 目录
|
- 文档更新: 提交 PR 到 `docs/` 目录
|
||||||
|
|
||||||
## 版本历史
|
## 版本历史
|
||||||
|
|
||||||
| 版本 | 日期 | 更新内容 |
|
| 版本 | 日期 | 更新内容 |
|
||||||
|------|------|---------|
|
|------|------|---------|
|
||||||
| 1.0.0 | 2024-01 | 初始版本,支持 2-of-3 |
|
| 1.0.0 | 2024-01 | 初始版本,支持 2-of-3 |
|
||||||
| 1.1.0 | 2024-01 | 添加 3-of-5, 4-of-7 支持 |
|
| 1.1.0 | 2024-01 | 添加 3-of-5, 4-of-7 支持 |
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,320 +1,320 @@
|
||||||
-- MPC Distributed Signature System Database Schema
|
-- MPC Distributed Signature System Database Schema
|
||||||
-- Version: 001
|
-- Version: 001
|
||||||
-- Description: Initial schema creation
|
-- Description: Initial schema creation
|
||||||
|
|
||||||
-- Enable UUID extension
|
-- Enable UUID extension
|
||||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
||||||
CREATE EXTENSION IF NOT EXISTS "pgcrypto";
|
CREATE EXTENSION IF NOT EXISTS "pgcrypto";
|
||||||
|
|
||||||
-- ============================================
|
-- ============================================
|
||||||
-- Session Coordinator Schema
|
-- Session Coordinator Schema
|
||||||
-- ============================================
|
-- ============================================
|
||||||
|
|
||||||
-- MPC Sessions table
|
-- MPC Sessions table
|
||||||
CREATE TABLE mpc_sessions (
|
CREATE TABLE mpc_sessions (
|
||||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
session_type VARCHAR(20) NOT NULL, -- 'keygen' or 'sign'
|
session_type VARCHAR(20) NOT NULL, -- 'keygen' or 'sign'
|
||||||
threshold_n INTEGER NOT NULL,
|
threshold_n INTEGER NOT NULL,
|
||||||
threshold_t INTEGER NOT NULL,
|
threshold_t INTEGER NOT NULL,
|
||||||
status VARCHAR(20) NOT NULL,
|
status VARCHAR(20) NOT NULL,
|
||||||
message_hash BYTEA, -- For Sign sessions
|
message_hash BYTEA, -- For Sign sessions
|
||||||
public_key BYTEA, -- Group public key after Keygen completion
|
public_key BYTEA, -- Group public key after Keygen completion
|
||||||
created_by VARCHAR(255) NOT NULL,
|
created_by VARCHAR(255) NOT NULL,
|
||||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||||
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||||
expires_at TIMESTAMP NOT NULL,
|
expires_at TIMESTAMP NOT NULL,
|
||||||
completed_at TIMESTAMP,
|
completed_at TIMESTAMP,
|
||||||
CONSTRAINT chk_threshold CHECK (threshold_t <= threshold_n AND threshold_t > 0),
|
CONSTRAINT chk_threshold CHECK (threshold_t <= threshold_n AND threshold_t > 0),
|
||||||
CONSTRAINT chk_session_type CHECK (session_type IN ('keygen', 'sign')),
|
CONSTRAINT chk_session_type CHECK (session_type IN ('keygen', 'sign')),
|
||||||
CONSTRAINT chk_status CHECK (status IN ('created', 'in_progress', 'completed', 'failed', 'expired'))
|
CONSTRAINT chk_status CHECK (status IN ('created', 'in_progress', 'completed', 'failed', 'expired'))
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Indexes for mpc_sessions
|
-- Indexes for mpc_sessions
|
||||||
CREATE INDEX idx_mpc_sessions_status ON mpc_sessions(status);
|
CREATE INDEX idx_mpc_sessions_status ON mpc_sessions(status);
|
||||||
CREATE INDEX idx_mpc_sessions_created_at ON mpc_sessions(created_at);
|
CREATE INDEX idx_mpc_sessions_created_at ON mpc_sessions(created_at);
|
||||||
CREATE INDEX idx_mpc_sessions_expires_at ON mpc_sessions(expires_at);
|
CREATE INDEX idx_mpc_sessions_expires_at ON mpc_sessions(expires_at);
|
||||||
CREATE INDEX idx_mpc_sessions_created_by ON mpc_sessions(created_by);
|
CREATE INDEX idx_mpc_sessions_created_by ON mpc_sessions(created_by);
|
||||||
|
|
||||||
-- Session Participants table
|
-- Session Participants table
|
||||||
CREATE TABLE participants (
|
CREATE TABLE participants (
|
||||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
session_id UUID NOT NULL REFERENCES mpc_sessions(id) ON DELETE CASCADE,
|
session_id UUID NOT NULL REFERENCES mpc_sessions(id) ON DELETE CASCADE,
|
||||||
party_id VARCHAR(255) NOT NULL,
|
party_id VARCHAR(255) NOT NULL,
|
||||||
party_index INTEGER NOT NULL,
|
party_index INTEGER NOT NULL,
|
||||||
status VARCHAR(20) NOT NULL,
|
status VARCHAR(20) NOT NULL,
|
||||||
device_type VARCHAR(50),
|
device_type VARCHAR(50),
|
||||||
device_id VARCHAR(255),
|
device_id VARCHAR(255),
|
||||||
platform VARCHAR(50),
|
platform VARCHAR(50),
|
||||||
app_version VARCHAR(50),
|
app_version VARCHAR(50),
|
||||||
public_key BYTEA, -- Party identity public key (for authentication)
|
public_key BYTEA, -- Party identity public key (for authentication)
|
||||||
joined_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
joined_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||||
completed_at TIMESTAMP,
|
completed_at TIMESTAMP,
|
||||||
CONSTRAINT chk_participant_status CHECK (status IN ('invited', 'joined', 'ready', 'completed', 'failed')),
|
CONSTRAINT chk_participant_status CHECK (status IN ('invited', 'joined', 'ready', 'completed', 'failed')),
|
||||||
UNIQUE(session_id, party_id),
|
UNIQUE(session_id, party_id),
|
||||||
UNIQUE(session_id, party_index)
|
UNIQUE(session_id, party_index)
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Indexes for participants
|
-- Indexes for participants
|
||||||
CREATE INDEX idx_participants_session_id ON participants(session_id);
|
CREATE INDEX idx_participants_session_id ON participants(session_id);
|
||||||
CREATE INDEX idx_participants_party_id ON participants(party_id);
|
CREATE INDEX idx_participants_party_id ON participants(party_id);
|
||||||
CREATE INDEX idx_participants_status ON participants(status);
|
CREATE INDEX idx_participants_status ON participants(status);
|
||||||
|
|
||||||
-- ============================================
|
-- ============================================
|
||||||
-- Message Router Schema
|
-- Message Router Schema
|
||||||
-- ============================================
|
-- ============================================
|
||||||
|
|
||||||
-- MPC Messages table (for offline message caching)
|
-- MPC Messages table (for offline message caching)
|
||||||
CREATE TABLE mpc_messages (
|
CREATE TABLE mpc_messages (
|
||||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
session_id UUID NOT NULL REFERENCES mpc_sessions(id) ON DELETE CASCADE,
|
session_id UUID NOT NULL REFERENCES mpc_sessions(id) ON DELETE CASCADE,
|
||||||
from_party VARCHAR(255) NOT NULL,
|
from_party VARCHAR(255) NOT NULL,
|
||||||
to_parties TEXT[], -- NULL means broadcast
|
to_parties TEXT[], -- NULL means broadcast
|
||||||
round_number INTEGER NOT NULL,
|
round_number INTEGER NOT NULL,
|
||||||
message_type VARCHAR(50) NOT NULL,
|
message_type VARCHAR(50) NOT NULL,
|
||||||
payload BYTEA NOT NULL, -- Encrypted MPC message
|
payload BYTEA NOT NULL, -- Encrypted MPC message
|
||||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||||
delivered_at TIMESTAMP,
|
delivered_at TIMESTAMP,
|
||||||
CONSTRAINT chk_round_number CHECK (round_number >= 0)
|
CONSTRAINT chk_round_number CHECK (round_number >= 0)
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Indexes for mpc_messages
|
-- Indexes for mpc_messages
|
||||||
CREATE INDEX idx_mpc_messages_session_id ON mpc_messages(session_id);
|
CREATE INDEX idx_mpc_messages_session_id ON mpc_messages(session_id);
|
||||||
CREATE INDEX idx_mpc_messages_to_parties ON mpc_messages USING GIN(to_parties);
|
CREATE INDEX idx_mpc_messages_to_parties ON mpc_messages USING GIN(to_parties);
|
||||||
CREATE INDEX idx_mpc_messages_delivered_at ON mpc_messages(delivered_at) WHERE delivered_at IS NULL;
|
CREATE INDEX idx_mpc_messages_delivered_at ON mpc_messages(delivered_at) WHERE delivered_at IS NULL;
|
||||||
CREATE INDEX idx_mpc_messages_created_at ON mpc_messages(created_at);
|
CREATE INDEX idx_mpc_messages_created_at ON mpc_messages(created_at);
|
||||||
CREATE INDEX idx_mpc_messages_round ON mpc_messages(session_id, round_number);
|
CREATE INDEX idx_mpc_messages_round ON mpc_messages(session_id, round_number);
|
||||||
|
|
||||||
-- ============================================
|
-- ============================================
|
||||||
-- Server Party Service Schema
|
-- Server Party Service Schema
|
||||||
-- ============================================
|
-- ============================================
|
||||||
|
|
||||||
-- Party Key Shares table (Server Party's own Share)
|
-- Party Key Shares table (Server Party's own Share)
|
||||||
CREATE TABLE party_key_shares (
|
CREATE TABLE party_key_shares (
|
||||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
party_id VARCHAR(255) NOT NULL,
|
party_id VARCHAR(255) NOT NULL,
|
||||||
party_index INTEGER NOT NULL,
|
party_index INTEGER NOT NULL,
|
||||||
session_id UUID NOT NULL, -- Keygen session ID
|
session_id UUID NOT NULL, -- Keygen session ID
|
||||||
threshold_n INTEGER NOT NULL,
|
threshold_n INTEGER NOT NULL,
|
||||||
threshold_t INTEGER NOT NULL,
|
threshold_t INTEGER NOT NULL,
|
||||||
share_data BYTEA NOT NULL, -- Encrypted tss-lib LocalPartySaveData
|
share_data BYTEA NOT NULL, -- Encrypted tss-lib LocalPartySaveData
|
||||||
public_key BYTEA NOT NULL, -- Group public key
|
public_key BYTEA NOT NULL, -- Group public key
|
||||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||||
last_used_at TIMESTAMP,
|
last_used_at TIMESTAMP,
|
||||||
CONSTRAINT chk_key_share_threshold CHECK (threshold_t <= threshold_n)
|
CONSTRAINT chk_key_share_threshold CHECK (threshold_t <= threshold_n)
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Indexes for party_key_shares
|
-- Indexes for party_key_shares
|
||||||
CREATE INDEX idx_party_key_shares_party_id ON party_key_shares(party_id);
|
CREATE INDEX idx_party_key_shares_party_id ON party_key_shares(party_id);
|
||||||
CREATE INDEX idx_party_key_shares_session_id ON party_key_shares(session_id);
|
CREATE INDEX idx_party_key_shares_session_id ON party_key_shares(session_id);
|
||||||
CREATE INDEX idx_party_key_shares_public_key ON party_key_shares(public_key);
|
CREATE INDEX idx_party_key_shares_public_key ON party_key_shares(public_key);
|
||||||
CREATE UNIQUE INDEX idx_party_key_shares_unique ON party_key_shares(party_id, session_id);
|
CREATE UNIQUE INDEX idx_party_key_shares_unique ON party_key_shares(party_id, session_id);
|
||||||
|
|
||||||
-- ============================================
|
-- ============================================
|
||||||
-- Account Service Schema
|
-- Account Service Schema
|
||||||
-- ============================================
|
-- ============================================
|
||||||
|
|
||||||
-- Accounts table
|
-- Accounts table
|
||||||
CREATE TABLE accounts (
|
CREATE TABLE accounts (
|
||||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
username VARCHAR(255) UNIQUE NOT NULL,
|
username VARCHAR(255) UNIQUE NOT NULL,
|
||||||
email VARCHAR(255) UNIQUE NOT NULL,
|
email VARCHAR(255) UNIQUE NOT NULL,
|
||||||
phone VARCHAR(50),
|
phone VARCHAR(50),
|
||||||
public_key BYTEA NOT NULL, -- MPC group public key
|
public_key BYTEA NOT NULL, -- MPC group public key
|
||||||
keygen_session_id UUID NOT NULL, -- Related Keygen session
|
keygen_session_id UUID NOT NULL, -- Related Keygen session
|
||||||
threshold_n INTEGER NOT NULL,
|
threshold_n INTEGER NOT NULL,
|
||||||
threshold_t INTEGER NOT NULL,
|
threshold_t INTEGER NOT NULL,
|
||||||
status VARCHAR(20) NOT NULL,
|
status VARCHAR(20) NOT NULL,
|
||||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||||
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||||
last_login_at TIMESTAMP,
|
last_login_at TIMESTAMP,
|
||||||
CONSTRAINT chk_account_status CHECK (status IN ('active', 'suspended', 'locked', 'recovering'))
|
CONSTRAINT chk_account_status CHECK (status IN ('active', 'suspended', 'locked', 'recovering'))
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Indexes for accounts
|
-- Indexes for accounts
|
||||||
CREATE INDEX idx_accounts_username ON accounts(username);
|
CREATE INDEX idx_accounts_username ON accounts(username);
|
||||||
CREATE INDEX idx_accounts_email ON accounts(email);
|
CREATE INDEX idx_accounts_email ON accounts(email);
|
||||||
CREATE INDEX idx_accounts_public_key ON accounts(public_key);
|
CREATE INDEX idx_accounts_public_key ON accounts(public_key);
|
||||||
CREATE INDEX idx_accounts_status ON accounts(status);
|
CREATE INDEX idx_accounts_status ON accounts(status);
|
||||||
|
|
||||||
-- Account Share Mapping table (records share locations, not share content)
|
-- Account Share Mapping table (records share locations, not share content)
|
||||||
CREATE TABLE account_shares (
|
CREATE TABLE account_shares (
|
||||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
account_id UUID NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
|
account_id UUID NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
|
||||||
share_type VARCHAR(20) NOT NULL, -- 'user_device', 'server', 'recovery'
|
share_type VARCHAR(20) NOT NULL, -- 'user_device', 'server', 'recovery'
|
||||||
party_id VARCHAR(255) NOT NULL,
|
party_id VARCHAR(255) NOT NULL,
|
||||||
party_index INTEGER NOT NULL,
|
party_index INTEGER NOT NULL,
|
||||||
device_type VARCHAR(50),
|
device_type VARCHAR(50),
|
||||||
device_id VARCHAR(255),
|
device_id VARCHAR(255),
|
||||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||||
last_used_at TIMESTAMP,
|
last_used_at TIMESTAMP,
|
||||||
is_active BOOLEAN DEFAULT TRUE,
|
is_active BOOLEAN DEFAULT TRUE,
|
||||||
CONSTRAINT chk_share_type CHECK (share_type IN ('user_device', 'server', 'recovery'))
|
CONSTRAINT chk_share_type CHECK (share_type IN ('user_device', 'server', 'recovery'))
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Indexes for account_shares
|
-- Indexes for account_shares
|
||||||
CREATE INDEX idx_account_shares_account_id ON account_shares(account_id);
|
CREATE INDEX idx_account_shares_account_id ON account_shares(account_id);
|
||||||
CREATE INDEX idx_account_shares_party_id ON account_shares(party_id);
|
CREATE INDEX idx_account_shares_party_id ON account_shares(party_id);
|
||||||
CREATE INDEX idx_account_shares_active ON account_shares(account_id, is_active) WHERE is_active = TRUE;
|
CREATE INDEX idx_account_shares_active ON account_shares(account_id, is_active) WHERE is_active = TRUE;
|
||||||
|
|
||||||
-- Account Recovery Sessions table
|
-- Account Recovery Sessions table
|
||||||
CREATE TABLE account_recovery_sessions (
|
CREATE TABLE account_recovery_sessions (
|
||||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
account_id UUID NOT NULL REFERENCES accounts(id),
|
account_id UUID NOT NULL REFERENCES accounts(id),
|
||||||
recovery_type VARCHAR(20) NOT NULL, -- 'device_lost', 'share_rotation'
|
recovery_type VARCHAR(20) NOT NULL, -- 'device_lost', 'share_rotation'
|
||||||
old_share_type VARCHAR(20),
|
old_share_type VARCHAR(20),
|
||||||
new_keygen_session_id UUID,
|
new_keygen_session_id UUID,
|
||||||
status VARCHAR(20) NOT NULL,
|
status VARCHAR(20) NOT NULL,
|
||||||
requested_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
requested_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||||
completed_at TIMESTAMP,
|
completed_at TIMESTAMP,
|
||||||
CONSTRAINT chk_recovery_status CHECK (status IN ('requested', 'in_progress', 'completed', 'failed'))
|
CONSTRAINT chk_recovery_status CHECK (status IN ('requested', 'in_progress', 'completed', 'failed'))
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Indexes for account_recovery_sessions
|
-- Indexes for account_recovery_sessions
|
||||||
CREATE INDEX idx_account_recovery_account_id ON account_recovery_sessions(account_id);
|
CREATE INDEX idx_account_recovery_account_id ON account_recovery_sessions(account_id);
|
||||||
CREATE INDEX idx_account_recovery_status ON account_recovery_sessions(status);
|
CREATE INDEX idx_account_recovery_status ON account_recovery_sessions(status);
|
||||||
|
|
||||||
-- ============================================
|
-- ============================================
|
||||||
-- Audit Service Schema
|
-- Audit Service Schema
|
||||||
-- ============================================
|
-- ============================================
|
||||||
|
|
||||||
-- Audit Workflows table
|
-- Audit Workflows table
|
||||||
CREATE TABLE audit_workflows (
|
CREATE TABLE audit_workflows (
|
||||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
workflow_name VARCHAR(255) NOT NULL,
|
workflow_name VARCHAR(255) NOT NULL,
|
||||||
workflow_type VARCHAR(50) NOT NULL,
|
workflow_type VARCHAR(50) NOT NULL,
|
||||||
data_hash BYTEA NOT NULL,
|
data_hash BYTEA NOT NULL,
|
||||||
threshold_n INTEGER NOT NULL,
|
threshold_n INTEGER NOT NULL,
|
||||||
threshold_t INTEGER NOT NULL,
|
threshold_t INTEGER NOT NULL,
|
||||||
sign_session_id UUID, -- Related signing session
|
sign_session_id UUID, -- Related signing session
|
||||||
signature BYTEA,
|
signature BYTEA,
|
||||||
status VARCHAR(20) NOT NULL,
|
status VARCHAR(20) NOT NULL,
|
||||||
created_by VARCHAR(255) NOT NULL,
|
created_by VARCHAR(255) NOT NULL,
|
||||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||||
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||||
expires_at TIMESTAMP,
|
expires_at TIMESTAMP,
|
||||||
completed_at TIMESTAMP,
|
completed_at TIMESTAMP,
|
||||||
metadata JSONB,
|
metadata JSONB,
|
||||||
CONSTRAINT chk_audit_workflow_status CHECK (status IN ('pending', 'in_progress', 'approved', 'rejected', 'expired'))
|
CONSTRAINT chk_audit_workflow_status CHECK (status IN ('pending', 'in_progress', 'approved', 'rejected', 'expired'))
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Indexes for audit_workflows
|
-- Indexes for audit_workflows
|
||||||
CREATE INDEX idx_audit_workflows_status ON audit_workflows(status);
|
CREATE INDEX idx_audit_workflows_status ON audit_workflows(status);
|
||||||
CREATE INDEX idx_audit_workflows_created_at ON audit_workflows(created_at);
|
CREATE INDEX idx_audit_workflows_created_at ON audit_workflows(created_at);
|
||||||
CREATE INDEX idx_audit_workflows_workflow_type ON audit_workflows(workflow_type);
|
CREATE INDEX idx_audit_workflows_workflow_type ON audit_workflows(workflow_type);
|
||||||
|
|
||||||
-- Audit Approvers table
|
-- Audit Approvers table
|
||||||
CREATE TABLE audit_approvers (
|
CREATE TABLE audit_approvers (
|
||||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
workflow_id UUID NOT NULL REFERENCES audit_workflows(id) ON DELETE CASCADE,
|
workflow_id UUID NOT NULL REFERENCES audit_workflows(id) ON DELETE CASCADE,
|
||||||
approver_id VARCHAR(255) NOT NULL,
|
approver_id VARCHAR(255) NOT NULL,
|
||||||
party_id VARCHAR(255) NOT NULL,
|
party_id VARCHAR(255) NOT NULL,
|
||||||
party_index INTEGER NOT NULL,
|
party_index INTEGER NOT NULL,
|
||||||
status VARCHAR(20) NOT NULL,
|
status VARCHAR(20) NOT NULL,
|
||||||
approved_at TIMESTAMP,
|
approved_at TIMESTAMP,
|
||||||
comments TEXT,
|
comments TEXT,
|
||||||
CONSTRAINT chk_approver_status CHECK (status IN ('pending', 'approved', 'rejected')),
|
CONSTRAINT chk_approver_status CHECK (status IN ('pending', 'approved', 'rejected')),
|
||||||
UNIQUE(workflow_id, approver_id)
|
UNIQUE(workflow_id, approver_id)
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Indexes for audit_approvers
|
-- Indexes for audit_approvers
|
||||||
CREATE INDEX idx_audit_approvers_workflow_id ON audit_approvers(workflow_id);
|
CREATE INDEX idx_audit_approvers_workflow_id ON audit_approvers(workflow_id);
|
||||||
CREATE INDEX idx_audit_approvers_approver_id ON audit_approvers(approver_id);
|
CREATE INDEX idx_audit_approvers_approver_id ON audit_approvers(approver_id);
|
||||||
CREATE INDEX idx_audit_approvers_status ON audit_approvers(status);
|
CREATE INDEX idx_audit_approvers_status ON audit_approvers(status);
|
||||||
|
|
||||||
-- ============================================
|
-- ============================================
|
||||||
-- Shared Audit Logs Schema
|
-- Shared Audit Logs Schema
|
||||||
-- ============================================
|
-- ============================================
|
||||||
|
|
||||||
-- Audit Logs table (shared across all services)
|
-- Audit Logs table (shared across all services)
|
||||||
CREATE TABLE audit_logs (
|
CREATE TABLE audit_logs (
|
||||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
service_name VARCHAR(100) NOT NULL,
|
service_name VARCHAR(100) NOT NULL,
|
||||||
action_type VARCHAR(100) NOT NULL,
|
action_type VARCHAR(100) NOT NULL,
|
||||||
user_id VARCHAR(255),
|
user_id VARCHAR(255),
|
||||||
resource_type VARCHAR(100),
|
resource_type VARCHAR(100),
|
||||||
resource_id VARCHAR(255),
|
resource_id VARCHAR(255),
|
||||||
session_id UUID,
|
session_id UUID,
|
||||||
ip_address INET,
|
ip_address INET,
|
||||||
user_agent TEXT,
|
user_agent TEXT,
|
||||||
request_data JSONB,
|
request_data JSONB,
|
||||||
response_data JSONB,
|
response_data JSONB,
|
||||||
status VARCHAR(20) NOT NULL,
|
status VARCHAR(20) NOT NULL,
|
||||||
error_message TEXT,
|
error_message TEXT,
|
||||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||||
CONSTRAINT chk_audit_status CHECK (status IN ('success', 'failure', 'pending'))
|
CONSTRAINT chk_audit_status CHECK (status IN ('success', 'failure', 'pending'))
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Indexes for audit_logs
|
-- Indexes for audit_logs
|
||||||
CREATE INDEX idx_audit_logs_created_at ON audit_logs(created_at);
|
CREATE INDEX idx_audit_logs_created_at ON audit_logs(created_at);
|
||||||
CREATE INDEX idx_audit_logs_user_id ON audit_logs(user_id);
|
CREATE INDEX idx_audit_logs_user_id ON audit_logs(user_id);
|
||||||
CREATE INDEX idx_audit_logs_session_id ON audit_logs(session_id);
|
CREATE INDEX idx_audit_logs_session_id ON audit_logs(session_id);
|
||||||
CREATE INDEX idx_audit_logs_action_type ON audit_logs(action_type);
|
CREATE INDEX idx_audit_logs_action_type ON audit_logs(action_type);
|
||||||
CREATE INDEX idx_audit_logs_service_name ON audit_logs(service_name);
|
CREATE INDEX idx_audit_logs_service_name ON audit_logs(service_name);
|
||||||
|
|
||||||
-- Partitioning for audit_logs (if needed for large scale)
|
-- Partitioning for audit_logs (if needed for large scale)
|
||||||
-- CREATE TABLE audit_logs_y2024m01 PARTITION OF audit_logs
|
-- CREATE TABLE audit_logs_y2024m01 PARTITION OF audit_logs
|
||||||
-- FOR VALUES FROM ('2024-01-01') TO ('2024-02-01');
|
-- FOR VALUES FROM ('2024-01-01') TO ('2024-02-01');
|
||||||
|
|
||||||
-- ============================================
|
-- ============================================
|
||||||
-- Helper Functions
|
-- Helper Functions
|
||||||
-- ============================================
|
-- ============================================
|
||||||
|
|
||||||
-- Function to update updated_at timestamp
|
-- Function to update updated_at timestamp
|
||||||
CREATE OR REPLACE FUNCTION update_updated_at_column()
|
CREATE OR REPLACE FUNCTION update_updated_at_column()
|
||||||
RETURNS TRIGGER AS $$
|
RETURNS TRIGGER AS $$
|
||||||
BEGIN
|
BEGIN
|
||||||
NEW.updated_at = NOW();
|
NEW.updated_at = NOW();
|
||||||
RETURN NEW;
|
RETURN NEW;
|
||||||
END;
|
END;
|
||||||
$$ language 'plpgsql';
|
$$ language 'plpgsql';
|
||||||
|
|
||||||
-- Triggers for auto-updating updated_at
|
-- Triggers for auto-updating updated_at
|
||||||
CREATE TRIGGER update_mpc_sessions_updated_at
|
CREATE TRIGGER update_mpc_sessions_updated_at
|
||||||
BEFORE UPDATE ON mpc_sessions
|
BEFORE UPDATE ON mpc_sessions
|
||||||
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
|
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
|
||||||
|
|
||||||
CREATE TRIGGER update_accounts_updated_at
|
CREATE TRIGGER update_accounts_updated_at
|
||||||
BEFORE UPDATE ON accounts
|
BEFORE UPDATE ON accounts
|
||||||
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
|
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
|
||||||
|
|
||||||
CREATE TRIGGER update_audit_workflows_updated_at
|
CREATE TRIGGER update_audit_workflows_updated_at
|
||||||
BEFORE UPDATE ON audit_workflows
|
BEFORE UPDATE ON audit_workflows
|
||||||
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
|
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
|
||||||
|
|
||||||
-- Function to cleanup expired sessions
|
-- Function to cleanup expired sessions
|
||||||
CREATE OR REPLACE FUNCTION cleanup_expired_sessions()
|
CREATE OR REPLACE FUNCTION cleanup_expired_sessions()
|
||||||
RETURNS INTEGER AS $$
|
RETURNS INTEGER AS $$
|
||||||
DECLARE
|
DECLARE
|
||||||
deleted_count INTEGER;
|
deleted_count INTEGER;
|
||||||
BEGIN
|
BEGIN
|
||||||
UPDATE mpc_sessions
|
UPDATE mpc_sessions
|
||||||
SET status = 'expired', updated_at = NOW()
|
SET status = 'expired', updated_at = NOW()
|
||||||
WHERE expires_at < NOW()
|
WHERE expires_at < NOW()
|
||||||
AND status IN ('created', 'in_progress');
|
AND status IN ('created', 'in_progress');
|
||||||
|
|
||||||
GET DIAGNOSTICS deleted_count = ROW_COUNT;
|
GET DIAGNOSTICS deleted_count = ROW_COUNT;
|
||||||
RETURN deleted_count;
|
RETURN deleted_count;
|
||||||
END;
|
END;
|
||||||
$$ language 'plpgsql';
|
$$ language 'plpgsql';
|
||||||
|
|
||||||
-- Function to cleanup old messages
|
-- Function to cleanup old messages
|
||||||
CREATE OR REPLACE FUNCTION cleanup_old_messages(retention_hours INTEGER DEFAULT 24)
|
CREATE OR REPLACE FUNCTION cleanup_old_messages(retention_hours INTEGER DEFAULT 24)
|
||||||
RETURNS INTEGER AS $$
|
RETURNS INTEGER AS $$
|
||||||
DECLARE
|
DECLARE
|
||||||
deleted_count INTEGER;
|
deleted_count INTEGER;
|
||||||
BEGIN
|
BEGIN
|
||||||
DELETE FROM mpc_messages
|
DELETE FROM mpc_messages
|
||||||
WHERE created_at < NOW() - (retention_hours || ' hours')::INTERVAL;
|
WHERE created_at < NOW() - (retention_hours || ' hours')::INTERVAL;
|
||||||
|
|
||||||
GET DIAGNOSTICS deleted_count = ROW_COUNT;
|
GET DIAGNOSTICS deleted_count = ROW_COUNT;
|
||||||
RETURN deleted_count;
|
RETURN deleted_count;
|
||||||
END;
|
END;
|
||||||
$$ language 'plpgsql';
|
$$ language 'plpgsql';
|
||||||
|
|
||||||
-- Comments
|
-- Comments
|
||||||
COMMENT ON TABLE mpc_sessions IS 'MPC session management - Coordinator does not participate in MPC computation';
|
COMMENT ON TABLE mpc_sessions IS 'MPC session management - Coordinator does not participate in MPC computation';
|
||||||
COMMENT ON TABLE participants IS 'Session participants - tracks join status of each party';
|
COMMENT ON TABLE participants IS 'Session participants - tracks join status of each party';
|
||||||
COMMENT ON TABLE mpc_messages IS 'MPC protocol messages - encrypted, router does not decrypt';
|
COMMENT ON TABLE mpc_messages IS 'MPC protocol messages - encrypted, router does not decrypt';
|
||||||
COMMENT ON TABLE party_key_shares IS 'Server party key shares - encrypted storage of tss-lib data';
|
COMMENT ON TABLE party_key_shares IS 'Server party key shares - encrypted storage of tss-lib data';
|
||||||
COMMENT ON TABLE accounts IS 'User accounts with MPC-based authentication';
|
COMMENT ON TABLE accounts IS 'User accounts with MPC-based authentication';
|
||||||
COMMENT ON TABLE audit_logs IS 'Comprehensive audit trail for all operations';
|
COMMENT ON TABLE audit_logs IS 'Comprehensive audit trail for all operations';
|
||||||
|
|
|
||||||
|
|
@ -1,227 +1,227 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config holds all configuration for the MPC system
|
// Config holds all configuration for the MPC system
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `mapstructure:"server"`
|
Server ServerConfig `mapstructure:"server"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
Redis RedisConfig `mapstructure:"redis"`
|
Redis RedisConfig `mapstructure:"redis"`
|
||||||
RabbitMQ RabbitMQConfig `mapstructure:"rabbitmq"`
|
RabbitMQ RabbitMQConfig `mapstructure:"rabbitmq"`
|
||||||
Consul ConsulConfig `mapstructure:"consul"`
|
Consul ConsulConfig `mapstructure:"consul"`
|
||||||
JWT JWTConfig `mapstructure:"jwt"`
|
JWT JWTConfig `mapstructure:"jwt"`
|
||||||
MPC MPCConfig `mapstructure:"mpc"`
|
MPC MPCConfig `mapstructure:"mpc"`
|
||||||
Logger LoggerConfig `mapstructure:"logger"`
|
Logger LoggerConfig `mapstructure:"logger"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig holds server-related configuration
|
// ServerConfig holds server-related configuration
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
GRPCPort int `mapstructure:"grpc_port"`
|
GRPCPort int `mapstructure:"grpc_port"`
|
||||||
HTTPPort int `mapstructure:"http_port"`
|
HTTPPort int `mapstructure:"http_port"`
|
||||||
Environment string `mapstructure:"environment"`
|
Environment string `mapstructure:"environment"`
|
||||||
Timeout time.Duration `mapstructure:"timeout"`
|
Timeout time.Duration `mapstructure:"timeout"`
|
||||||
TLSEnabled bool `mapstructure:"tls_enabled"`
|
TLSEnabled bool `mapstructure:"tls_enabled"`
|
||||||
TLSCertFile string `mapstructure:"tls_cert_file"`
|
TLSCertFile string `mapstructure:"tls_cert_file"`
|
||||||
TLSKeyFile string `mapstructure:"tls_key_file"`
|
TLSKeyFile string `mapstructure:"tls_key_file"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabaseConfig holds database configuration
|
// DatabaseConfig holds database configuration
|
||||||
type DatabaseConfig struct {
|
type DatabaseConfig struct {
|
||||||
Host string `mapstructure:"host"`
|
Host string `mapstructure:"host"`
|
||||||
Port int `mapstructure:"port"`
|
Port int `mapstructure:"port"`
|
||||||
User string `mapstructure:"user"`
|
User string `mapstructure:"user"`
|
||||||
Password string `mapstructure:"password"`
|
Password string `mapstructure:"password"`
|
||||||
DBName string `mapstructure:"dbname"`
|
DBName string `mapstructure:"dbname"`
|
||||||
SSLMode string `mapstructure:"sslmode"`
|
SSLMode string `mapstructure:"sslmode"`
|
||||||
MaxOpenConns int `mapstructure:"max_open_conns"`
|
MaxOpenConns int `mapstructure:"max_open_conns"`
|
||||||
MaxIdleConns int `mapstructure:"max_idle_conns"`
|
MaxIdleConns int `mapstructure:"max_idle_conns"`
|
||||||
ConnMaxLife time.Duration `mapstructure:"conn_max_life"`
|
ConnMaxLife time.Duration `mapstructure:"conn_max_life"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DSN returns the database connection string
|
// DSN returns the database connection string
|
||||||
func (c *DatabaseConfig) DSN() string {
|
func (c *DatabaseConfig) DSN() string {
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||||
c.Host, c.Port, c.User, c.Password, c.DBName, c.SSLMode,
|
c.Host, c.Port, c.User, c.Password, c.DBName, c.SSLMode,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RedisConfig holds Redis configuration
|
// RedisConfig holds Redis configuration
|
||||||
type RedisConfig struct {
|
type RedisConfig struct {
|
||||||
Host string `mapstructure:"host"`
|
Host string `mapstructure:"host"`
|
||||||
Port int `mapstructure:"port"`
|
Port int `mapstructure:"port"`
|
||||||
Password string `mapstructure:"password"`
|
Password string `mapstructure:"password"`
|
||||||
DB int `mapstructure:"db"`
|
DB int `mapstructure:"db"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Addr returns the Redis address
|
// Addr returns the Redis address
|
||||||
func (c *RedisConfig) Addr() string {
|
func (c *RedisConfig) Addr() string {
|
||||||
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RabbitMQConfig holds RabbitMQ configuration
|
// RabbitMQConfig holds RabbitMQ configuration
|
||||||
type RabbitMQConfig struct {
|
type RabbitMQConfig struct {
|
||||||
Host string `mapstructure:"host"`
|
Host string `mapstructure:"host"`
|
||||||
Port int `mapstructure:"port"`
|
Port int `mapstructure:"port"`
|
||||||
User string `mapstructure:"user"`
|
User string `mapstructure:"user"`
|
||||||
Password string `mapstructure:"password"`
|
Password string `mapstructure:"password"`
|
||||||
VHost string `mapstructure:"vhost"`
|
VHost string `mapstructure:"vhost"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// URL returns the RabbitMQ connection URL
|
// URL returns the RabbitMQ connection URL
|
||||||
func (c *RabbitMQConfig) URL() string {
|
func (c *RabbitMQConfig) URL() string {
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"amqp://%s:%s@%s:%d/%s",
|
"amqp://%s:%s@%s:%d/%s",
|
||||||
c.User, c.Password, c.Host, c.Port, c.VHost,
|
c.User, c.Password, c.Host, c.Port, c.VHost,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConsulConfig holds Consul configuration
|
// ConsulConfig holds Consul configuration
|
||||||
type ConsulConfig struct {
|
type ConsulConfig struct {
|
||||||
Host string `mapstructure:"host"`
|
Host string `mapstructure:"host"`
|
||||||
Port int `mapstructure:"port"`
|
Port int `mapstructure:"port"`
|
||||||
ServiceID string `mapstructure:"service_id"`
|
ServiceID string `mapstructure:"service_id"`
|
||||||
Tags []string `mapstructure:"tags"`
|
Tags []string `mapstructure:"tags"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Addr returns the Consul address
|
// Addr returns the Consul address
|
||||||
func (c *ConsulConfig) Addr() string {
|
func (c *ConsulConfig) Addr() string {
|
||||||
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTConfig holds JWT configuration
|
// JWTConfig holds JWT configuration
|
||||||
type JWTConfig struct {
|
type JWTConfig struct {
|
||||||
SecretKey string `mapstructure:"secret_key"`
|
SecretKey string `mapstructure:"secret_key"`
|
||||||
Issuer string `mapstructure:"issuer"`
|
Issuer string `mapstructure:"issuer"`
|
||||||
TokenExpiry time.Duration `mapstructure:"token_expiry"`
|
TokenExpiry time.Duration `mapstructure:"token_expiry"`
|
||||||
RefreshExpiry time.Duration `mapstructure:"refresh_expiry"`
|
RefreshExpiry time.Duration `mapstructure:"refresh_expiry"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// MPCConfig holds MPC-specific configuration
|
// MPCConfig holds MPC-specific configuration
|
||||||
type MPCConfig struct {
|
type MPCConfig struct {
|
||||||
DefaultThresholdN int `mapstructure:"default_threshold_n"`
|
DefaultThresholdN int `mapstructure:"default_threshold_n"`
|
||||||
DefaultThresholdT int `mapstructure:"default_threshold_t"`
|
DefaultThresholdT int `mapstructure:"default_threshold_t"`
|
||||||
SessionTimeout time.Duration `mapstructure:"session_timeout"`
|
SessionTimeout time.Duration `mapstructure:"session_timeout"`
|
||||||
MessageTimeout time.Duration `mapstructure:"message_timeout"`
|
MessageTimeout time.Duration `mapstructure:"message_timeout"`
|
||||||
KeygenTimeout time.Duration `mapstructure:"keygen_timeout"`
|
KeygenTimeout time.Duration `mapstructure:"keygen_timeout"`
|
||||||
SigningTimeout time.Duration `mapstructure:"signing_timeout"`
|
SigningTimeout time.Duration `mapstructure:"signing_timeout"`
|
||||||
MaxParties int `mapstructure:"max_parties"`
|
MaxParties int `mapstructure:"max_parties"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoggerConfig holds logger configuration
|
// LoggerConfig holds logger configuration
|
||||||
type LoggerConfig struct {
|
type LoggerConfig struct {
|
||||||
Level string `mapstructure:"level"`
|
Level string `mapstructure:"level"`
|
||||||
Encoding string `mapstructure:"encoding"`
|
Encoding string `mapstructure:"encoding"`
|
||||||
OutputPath string `mapstructure:"output_path"`
|
OutputPath string `mapstructure:"output_path"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load loads configuration from file and environment variables
|
// Load loads configuration from file and environment variables
|
||||||
func Load(configPath string) (*Config, error) {
|
func Load(configPath string) (*Config, error) {
|
||||||
v := viper.New()
|
v := viper.New()
|
||||||
|
|
||||||
// Set default values
|
// Set default values
|
||||||
setDefaults(v)
|
setDefaults(v)
|
||||||
|
|
||||||
// Read config file
|
// Read config file
|
||||||
if configPath != "" {
|
if configPath != "" {
|
||||||
v.SetConfigFile(configPath)
|
v.SetConfigFile(configPath)
|
||||||
} else {
|
} else {
|
||||||
v.SetConfigName("config")
|
v.SetConfigName("config")
|
||||||
v.SetConfigType("yaml")
|
v.SetConfigType("yaml")
|
||||||
v.AddConfigPath(".")
|
v.AddConfigPath(".")
|
||||||
v.AddConfigPath("./config")
|
v.AddConfigPath("./config")
|
||||||
v.AddConfigPath("/etc/mpc-system/")
|
v.AddConfigPath("/etc/mpc-system/")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read environment variables
|
// Read environment variables
|
||||||
v.SetEnvPrefix("MPC")
|
v.SetEnvPrefix("MPC")
|
||||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||||
v.AutomaticEnv()
|
v.AutomaticEnv()
|
||||||
|
|
||||||
// Read config file (if exists)
|
// Read config file (if exists)
|
||||||
if err := v.ReadInConfig(); err != nil {
|
if err := v.ReadInConfig(); err != nil {
|
||||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||||
}
|
}
|
||||||
// Config file not found is not an error, we'll use defaults + env vars
|
// Config file not found is not an error, we'll use defaults + env vars
|
||||||
}
|
}
|
||||||
|
|
||||||
var config Config
|
var config Config
|
||||||
if err := v.Unmarshal(&config); err != nil {
|
if err := v.Unmarshal(&config); err != nil {
|
||||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &config, nil
|
return &config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setDefaults sets default configuration values
|
// setDefaults sets default configuration values
|
||||||
func setDefaults(v *viper.Viper) {
|
func setDefaults(v *viper.Viper) {
|
||||||
// Server defaults
|
// Server defaults
|
||||||
v.SetDefault("server.grpc_port", 50051)
|
v.SetDefault("server.grpc_port", 50051)
|
||||||
v.SetDefault("server.http_port", 8080)
|
v.SetDefault("server.http_port", 8080)
|
||||||
v.SetDefault("server.environment", "development")
|
v.SetDefault("server.environment", "development")
|
||||||
v.SetDefault("server.timeout", "30s")
|
v.SetDefault("server.timeout", "30s")
|
||||||
v.SetDefault("server.tls_enabled", false)
|
v.SetDefault("server.tls_enabled", false)
|
||||||
|
|
||||||
// Database defaults
|
// Database defaults
|
||||||
v.SetDefault("database.host", "localhost")
|
v.SetDefault("database.host", "localhost")
|
||||||
v.SetDefault("database.port", 5432)
|
v.SetDefault("database.port", 5432)
|
||||||
v.SetDefault("database.user", "mpc_user")
|
v.SetDefault("database.user", "mpc_user")
|
||||||
v.SetDefault("database.password", "")
|
v.SetDefault("database.password", "")
|
||||||
v.SetDefault("database.dbname", "mpc_system")
|
v.SetDefault("database.dbname", "mpc_system")
|
||||||
v.SetDefault("database.sslmode", "disable")
|
v.SetDefault("database.sslmode", "disable")
|
||||||
v.SetDefault("database.max_open_conns", 25)
|
v.SetDefault("database.max_open_conns", 25)
|
||||||
v.SetDefault("database.max_idle_conns", 5)
|
v.SetDefault("database.max_idle_conns", 5)
|
||||||
v.SetDefault("database.conn_max_life", "5m")
|
v.SetDefault("database.conn_max_life", "5m")
|
||||||
|
|
||||||
// Redis defaults
|
// Redis defaults
|
||||||
v.SetDefault("redis.host", "localhost")
|
v.SetDefault("redis.host", "localhost")
|
||||||
v.SetDefault("redis.port", 6379)
|
v.SetDefault("redis.port", 6379)
|
||||||
v.SetDefault("redis.password", "")
|
v.SetDefault("redis.password", "")
|
||||||
v.SetDefault("redis.db", 0)
|
v.SetDefault("redis.db", 0)
|
||||||
|
|
||||||
// RabbitMQ defaults
|
// RabbitMQ defaults
|
||||||
v.SetDefault("rabbitmq.host", "localhost")
|
v.SetDefault("rabbitmq.host", "localhost")
|
||||||
v.SetDefault("rabbitmq.port", 5672)
|
v.SetDefault("rabbitmq.port", 5672)
|
||||||
v.SetDefault("rabbitmq.user", "guest")
|
v.SetDefault("rabbitmq.user", "guest")
|
||||||
v.SetDefault("rabbitmq.password", "guest")
|
v.SetDefault("rabbitmq.password", "guest")
|
||||||
v.SetDefault("rabbitmq.vhost", "/")
|
v.SetDefault("rabbitmq.vhost", "/")
|
||||||
|
|
||||||
// Consul defaults
|
// Consul defaults
|
||||||
v.SetDefault("consul.host", "localhost")
|
v.SetDefault("consul.host", "localhost")
|
||||||
v.SetDefault("consul.port", 8500)
|
v.SetDefault("consul.port", 8500)
|
||||||
|
|
||||||
// JWT defaults
|
// JWT defaults
|
||||||
v.SetDefault("jwt.issuer", "mpc-system")
|
v.SetDefault("jwt.issuer", "mpc-system")
|
||||||
v.SetDefault("jwt.token_expiry", "15m")
|
v.SetDefault("jwt.token_expiry", "15m")
|
||||||
v.SetDefault("jwt.refresh_expiry", "24h")
|
v.SetDefault("jwt.refresh_expiry", "24h")
|
||||||
|
|
||||||
// MPC defaults
|
// MPC defaults
|
||||||
v.SetDefault("mpc.default_threshold_n", 3)
|
v.SetDefault("mpc.default_threshold_n", 3)
|
||||||
v.SetDefault("mpc.default_threshold_t", 2)
|
v.SetDefault("mpc.default_threshold_t", 2)
|
||||||
v.SetDefault("mpc.session_timeout", "10m")
|
v.SetDefault("mpc.session_timeout", "10m")
|
||||||
v.SetDefault("mpc.message_timeout", "30s")
|
v.SetDefault("mpc.message_timeout", "30s")
|
||||||
v.SetDefault("mpc.keygen_timeout", "10m")
|
v.SetDefault("mpc.keygen_timeout", "10m")
|
||||||
v.SetDefault("mpc.signing_timeout", "5m")
|
v.SetDefault("mpc.signing_timeout", "5m")
|
||||||
v.SetDefault("mpc.max_parties", 10)
|
v.SetDefault("mpc.max_parties", 10)
|
||||||
|
|
||||||
// Logger defaults
|
// Logger defaults
|
||||||
v.SetDefault("logger.level", "info")
|
v.SetDefault("logger.level", "info")
|
||||||
v.SetDefault("logger.encoding", "json")
|
v.SetDefault("logger.encoding", "json")
|
||||||
v.SetDefault("logger.output_path", "stdout")
|
v.SetDefault("logger.output_path", "stdout")
|
||||||
}
|
}
|
||||||
|
|
||||||
// MustLoad loads configuration and panics on error
|
// MustLoad loads configuration and panics on error
|
||||||
func MustLoad(configPath string) *Config {
|
func MustLoad(configPath string) *Config {
|
||||||
cfg, err := Load(configPath)
|
cfg, err := Load(configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("failed to load config: %v", err))
|
panic(fmt.Sprintf("failed to load config: %v", err))
|
||||||
}
|
}
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,374 +1,374 @@
|
||||||
package crypto
|
package crypto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/aes"
|
"crypto/aes"
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
|
||||||
"golang.org/x/crypto/hkdf"
|
"golang.org/x/crypto/hkdf"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidKeySize = errors.New("invalid key size")
|
ErrInvalidKeySize = errors.New("invalid key size")
|
||||||
ErrInvalidCipherText = errors.New("invalid ciphertext")
|
ErrInvalidCipherText = errors.New("invalid ciphertext")
|
||||||
ErrEncryptionFailed = errors.New("encryption failed")
|
ErrEncryptionFailed = errors.New("encryption failed")
|
||||||
ErrDecryptionFailed = errors.New("decryption failed")
|
ErrDecryptionFailed = errors.New("decryption failed")
|
||||||
ErrInvalidPublicKey = errors.New("invalid public key")
|
ErrInvalidPublicKey = errors.New("invalid public key")
|
||||||
ErrInvalidSignature = errors.New("invalid signature")
|
ErrInvalidSignature = errors.New("invalid signature")
|
||||||
)
|
)
|
||||||
|
|
||||||
// CryptoService provides cryptographic operations
|
// CryptoService provides cryptographic operations
|
||||||
type CryptoService struct {
|
type CryptoService struct {
|
||||||
masterKey []byte
|
masterKey []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCryptoService creates a new crypto service
|
// NewCryptoService creates a new crypto service
|
||||||
func NewCryptoService(masterKey []byte) (*CryptoService, error) {
|
func NewCryptoService(masterKey []byte) (*CryptoService, error) {
|
||||||
if len(masterKey) != 32 {
|
if len(masterKey) != 32 {
|
||||||
return nil, ErrInvalidKeySize
|
return nil, ErrInvalidKeySize
|
||||||
}
|
}
|
||||||
return &CryptoService{masterKey: masterKey}, nil
|
return &CryptoService{masterKey: masterKey}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateRandomBytes generates random bytes
|
// GenerateRandomBytes generates random bytes
|
||||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||||
b := make([]byte, n)
|
b := make([]byte, n)
|
||||||
_, err := rand.Read(b)
|
_, err := rand.Read(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateRandomHex generates a random hex string
|
// GenerateRandomHex generates a random hex string
|
||||||
func GenerateRandomHex(n int) (string, error) {
|
func GenerateRandomHex(n int) (string, error) {
|
||||||
bytes, err := GenerateRandomBytes(n)
|
bytes, err := GenerateRandomBytes(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return hex.EncodeToString(bytes), nil
|
return hex.EncodeToString(bytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeriveKey derives a key from the master key using HKDF
|
// DeriveKey derives a key from the master key using HKDF
|
||||||
func (c *CryptoService) DeriveKey(context string, length int) ([]byte, error) {
|
func (c *CryptoService) DeriveKey(context string, length int) ([]byte, error) {
|
||||||
hkdfReader := hkdf.New(sha256.New, c.masterKey, nil, []byte(context))
|
hkdfReader := hkdf.New(sha256.New, c.masterKey, nil, []byte(context))
|
||||||
key := make([]byte, length)
|
key := make([]byte, length)
|
||||||
if _, err := io.ReadFull(hkdfReader, key); err != nil {
|
if _, err := io.ReadFull(hkdfReader, key); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EncryptShare encrypts a key share using AES-256-GCM
|
// EncryptShare encrypts a key share using AES-256-GCM
|
||||||
func (c *CryptoService) EncryptShare(shareData []byte, partyID string) ([]byte, error) {
|
func (c *CryptoService) EncryptShare(shareData []byte, partyID string) ([]byte, error) {
|
||||||
// Derive a unique key for this party
|
// Derive a unique key for this party
|
||||||
key, err := c.DeriveKey("share_encryption:"+partyID, 32)
|
key, err := c.DeriveKey("share_encryption:"+partyID, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
block, err := aes.NewCipher(key)
|
block, err := aes.NewCipher(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
aesGCM, err := cipher.NewGCM(block)
|
aesGCM, err := cipher.NewGCM(block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce := make([]byte, aesGCM.NonceSize())
|
nonce := make([]byte, aesGCM.NonceSize())
|
||||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt and prepend nonce
|
// Encrypt and prepend nonce
|
||||||
ciphertext := aesGCM.Seal(nonce, nonce, shareData, []byte(partyID))
|
ciphertext := aesGCM.Seal(nonce, nonce, shareData, []byte(partyID))
|
||||||
return ciphertext, nil
|
return ciphertext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecryptShare decrypts a key share
|
// DecryptShare decrypts a key share
|
||||||
func (c *CryptoService) DecryptShare(encryptedData []byte, partyID string) ([]byte, error) {
|
func (c *CryptoService) DecryptShare(encryptedData []byte, partyID string) ([]byte, error) {
|
||||||
// Derive the same key used for encryption
|
// Derive the same key used for encryption
|
||||||
key, err := c.DeriveKey("share_encryption:"+partyID, 32)
|
key, err := c.DeriveKey("share_encryption:"+partyID, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
block, err := aes.NewCipher(key)
|
block, err := aes.NewCipher(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
aesGCM, err := cipher.NewGCM(block)
|
aesGCM, err := cipher.NewGCM(block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nonceSize := aesGCM.NonceSize()
|
nonceSize := aesGCM.NonceSize()
|
||||||
if len(encryptedData) < nonceSize {
|
if len(encryptedData) < nonceSize {
|
||||||
return nil, ErrInvalidCipherText
|
return nil, ErrInvalidCipherText
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce, ciphertext := encryptedData[:nonceSize], encryptedData[nonceSize:]
|
nonce, ciphertext := encryptedData[:nonceSize], encryptedData[nonceSize:]
|
||||||
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, []byte(partyID))
|
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, []byte(partyID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrDecryptionFailed
|
return nil, ErrDecryptionFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
return plaintext, nil
|
return plaintext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EncryptMessage encrypts a message using AES-256-GCM
|
// EncryptMessage encrypts a message using AES-256-GCM
|
||||||
func (c *CryptoService) EncryptMessage(plaintext []byte) ([]byte, error) {
|
func (c *CryptoService) EncryptMessage(plaintext []byte) ([]byte, error) {
|
||||||
block, err := aes.NewCipher(c.masterKey)
|
block, err := aes.NewCipher(c.masterKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
aesGCM, err := cipher.NewGCM(block)
|
aesGCM, err := cipher.NewGCM(block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce := make([]byte, aesGCM.NonceSize())
|
nonce := make([]byte, aesGCM.NonceSize())
|
||||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ciphertext := aesGCM.Seal(nonce, nonce, plaintext, nil)
|
ciphertext := aesGCM.Seal(nonce, nonce, plaintext, nil)
|
||||||
return ciphertext, nil
|
return ciphertext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecryptMessage decrypts a message
|
// DecryptMessage decrypts a message
|
||||||
func (c *CryptoService) DecryptMessage(ciphertext []byte) ([]byte, error) {
|
func (c *CryptoService) DecryptMessage(ciphertext []byte) ([]byte, error) {
|
||||||
block, err := aes.NewCipher(c.masterKey)
|
block, err := aes.NewCipher(c.masterKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
aesGCM, err := cipher.NewGCM(block)
|
aesGCM, err := cipher.NewGCM(block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nonceSize := aesGCM.NonceSize()
|
nonceSize := aesGCM.NonceSize()
|
||||||
if len(ciphertext) < nonceSize {
|
if len(ciphertext) < nonceSize {
|
||||||
return nil, ErrInvalidCipherText
|
return nil, ErrInvalidCipherText
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||||
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
|
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrDecryptionFailed
|
return nil, ErrDecryptionFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
return plaintext, nil
|
return plaintext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hash256 computes SHA-256 hash
|
// Hash256 computes SHA-256 hash
|
||||||
func Hash256(data []byte) []byte {
|
func Hash256(data []byte) []byte {
|
||||||
hash := sha256.Sum256(data)
|
hash := sha256.Sum256(data)
|
||||||
return hash[:]
|
return hash[:]
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyECDSASignature verifies an ECDSA signature
|
// VerifyECDSASignature verifies an ECDSA signature
|
||||||
func VerifyECDSASignature(messageHash, signature, publicKey []byte) (bool, error) {
|
func VerifyECDSASignature(messageHash, signature, publicKey []byte) (bool, error) {
|
||||||
// Parse public key (assuming secp256k1/P256 uncompressed format)
|
// Parse public key (assuming secp256k1/P256 uncompressed format)
|
||||||
curve := elliptic.P256()
|
curve := elliptic.P256()
|
||||||
x, y := elliptic.Unmarshal(curve, publicKey)
|
x, y := elliptic.Unmarshal(curve, publicKey)
|
||||||
if x == nil {
|
if x == nil {
|
||||||
return false, ErrInvalidPublicKey
|
return false, ErrInvalidPublicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
pubKey := &ecdsa.PublicKey{
|
pubKey := &ecdsa.PublicKey{
|
||||||
Curve: curve,
|
Curve: curve,
|
||||||
X: x,
|
X: x,
|
||||||
Y: y,
|
Y: y,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse signature (R || S, each 32 bytes)
|
// Parse signature (R || S, each 32 bytes)
|
||||||
if len(signature) != 64 {
|
if len(signature) != 64 {
|
||||||
return false, ErrInvalidSignature
|
return false, ErrInvalidSignature
|
||||||
}
|
}
|
||||||
|
|
||||||
r := new(big.Int).SetBytes(signature[:32])
|
r := new(big.Int).SetBytes(signature[:32])
|
||||||
s := new(big.Int).SetBytes(signature[32:])
|
s := new(big.Int).SetBytes(signature[32:])
|
||||||
|
|
||||||
// Verify signature
|
// Verify signature
|
||||||
valid := ecdsa.Verify(pubKey, messageHash, r, s)
|
valid := ecdsa.Verify(pubKey, messageHash, r, s)
|
||||||
return valid, nil
|
return valid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateNonce generates a cryptographic nonce
|
// GenerateNonce generates a cryptographic nonce
|
||||||
func GenerateNonce() ([]byte, error) {
|
func GenerateNonce() ([]byte, error) {
|
||||||
return GenerateRandomBytes(32)
|
return GenerateRandomBytes(32)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SecureCompare performs constant-time comparison
|
// SecureCompare performs constant-time comparison
|
||||||
func SecureCompare(a, b []byte) bool {
|
func SecureCompare(a, b []byte) bool {
|
||||||
if len(a) != len(b) {
|
if len(a) != len(b) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
var result byte
|
var result byte
|
||||||
for i := 0; i < len(a); i++ {
|
for i := 0; i < len(a); i++ {
|
||||||
result |= a[i] ^ b[i]
|
result |= a[i] ^ b[i]
|
||||||
}
|
}
|
||||||
return result == 0
|
return result == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParsePublicKey parses a public key from bytes (P256 uncompressed format)
|
// ParsePublicKey parses a public key from bytes (P256 uncompressed format)
|
||||||
func ParsePublicKey(publicKeyBytes []byte) (*ecdsa.PublicKey, error) {
|
func ParsePublicKey(publicKeyBytes []byte) (*ecdsa.PublicKey, error) {
|
||||||
curve := elliptic.P256()
|
curve := elliptic.P256()
|
||||||
x, y := elliptic.Unmarshal(curve, publicKeyBytes)
|
x, y := elliptic.Unmarshal(curve, publicKeyBytes)
|
||||||
if x == nil {
|
if x == nil {
|
||||||
return nil, ErrInvalidPublicKey
|
return nil, ErrInvalidPublicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ecdsa.PublicKey{
|
return &ecdsa.PublicKey{
|
||||||
Curve: curve,
|
Curve: curve,
|
||||||
X: x,
|
X: x,
|
||||||
Y: y,
|
Y: y,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifySignature verifies an ECDSA signature using a public key
|
// VerifySignature verifies an ECDSA signature using a public key
|
||||||
func VerifySignature(pubKey *ecdsa.PublicKey, messageHash, signature []byte) bool {
|
func VerifySignature(pubKey *ecdsa.PublicKey, messageHash, signature []byte) bool {
|
||||||
// Parse signature (R || S, each 32 bytes)
|
// Parse signature (R || S, each 32 bytes)
|
||||||
if len(signature) != 64 {
|
if len(signature) != 64 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
r := new(big.Int).SetBytes(signature[:32])
|
r := new(big.Int).SetBytes(signature[:32])
|
||||||
s := new(big.Int).SetBytes(signature[32:])
|
s := new(big.Int).SetBytes(signature[32:])
|
||||||
|
|
||||||
return ecdsa.Verify(pubKey, messageHash, r, s)
|
return ecdsa.Verify(pubKey, messageHash, r, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HashMessage computes SHA-256 hash of a message (alias for Hash256)
|
// HashMessage computes SHA-256 hash of a message (alias for Hash256)
|
||||||
func HashMessage(message []byte) []byte {
|
func HashMessage(message []byte) []byte {
|
||||||
return Hash256(message)
|
return Hash256(message)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt encrypts data using AES-256-GCM with the provided key
|
// Encrypt encrypts data using AES-256-GCM with the provided key
|
||||||
func Encrypt(key, plaintext []byte) ([]byte, error) {
|
func Encrypt(key, plaintext []byte) ([]byte, error) {
|
||||||
if len(key) != 32 {
|
if len(key) != 32 {
|
||||||
return nil, ErrInvalidKeySize
|
return nil, ErrInvalidKeySize
|
||||||
}
|
}
|
||||||
|
|
||||||
block, err := aes.NewCipher(key)
|
block, err := aes.NewCipher(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
aesGCM, err := cipher.NewGCM(block)
|
aesGCM, err := cipher.NewGCM(block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce := make([]byte, aesGCM.NonceSize())
|
nonce := make([]byte, aesGCM.NonceSize())
|
||||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ciphertext := aesGCM.Seal(nonce, nonce, plaintext, nil)
|
ciphertext := aesGCM.Seal(nonce, nonce, plaintext, nil)
|
||||||
return ciphertext, nil
|
return ciphertext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrypt decrypts data using AES-256-GCM with the provided key
|
// Decrypt decrypts data using AES-256-GCM with the provided key
|
||||||
func Decrypt(key, ciphertext []byte) ([]byte, error) {
|
func Decrypt(key, ciphertext []byte) ([]byte, error) {
|
||||||
if len(key) != 32 {
|
if len(key) != 32 {
|
||||||
return nil, ErrInvalidKeySize
|
return nil, ErrInvalidKeySize
|
||||||
}
|
}
|
||||||
|
|
||||||
block, err := aes.NewCipher(key)
|
block, err := aes.NewCipher(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
aesGCM, err := cipher.NewGCM(block)
|
aesGCM, err := cipher.NewGCM(block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nonceSize := aesGCM.NonceSize()
|
nonceSize := aesGCM.NonceSize()
|
||||||
if len(ciphertext) < nonceSize {
|
if len(ciphertext) < nonceSize {
|
||||||
return nil, ErrInvalidCipherText
|
return nil, ErrInvalidCipherText
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||||
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
|
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrDecryptionFailed
|
return nil, ErrDecryptionFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
return plaintext, nil
|
return plaintext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeriveKey derives a key from secret and salt using HKDF (standalone function)
|
// DeriveKey derives a key from secret and salt using HKDF (standalone function)
|
||||||
func DeriveKey(secret, salt []byte, length int) ([]byte, error) {
|
func DeriveKey(secret, salt []byte, length int) ([]byte, error) {
|
||||||
hkdfReader := hkdf.New(sha256.New, secret, salt, nil)
|
hkdfReader := hkdf.New(sha256.New, secret, salt, nil)
|
||||||
key := make([]byte, length)
|
key := make([]byte, length)
|
||||||
if _, err := io.ReadFull(hkdfReader, key); err != nil {
|
if _, err := io.ReadFull(hkdfReader, key); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignMessage signs a message using ECDSA private key
|
// SignMessage signs a message using ECDSA private key
|
||||||
func SignMessage(privateKey *ecdsa.PrivateKey, message []byte) ([]byte, error) {
|
func SignMessage(privateKey *ecdsa.PrivateKey, message []byte) ([]byte, error) {
|
||||||
hash := Hash256(message)
|
hash := Hash256(message)
|
||||||
r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash)
|
r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode R and S as 32 bytes each (total 64 bytes)
|
// Encode R and S as 32 bytes each (total 64 bytes)
|
||||||
signature := make([]byte, 64)
|
signature := make([]byte, 64)
|
||||||
rBytes := r.Bytes()
|
rBytes := r.Bytes()
|
||||||
sBytes := s.Bytes()
|
sBytes := s.Bytes()
|
||||||
|
|
||||||
// Pad with zeros if necessary
|
// Pad with zeros if necessary
|
||||||
copy(signature[32-len(rBytes):32], rBytes)
|
copy(signature[32-len(rBytes):32], rBytes)
|
||||||
copy(signature[64-len(sBytes):64], sBytes)
|
copy(signature[64-len(sBytes):64], sBytes)
|
||||||
|
|
||||||
return signature, nil
|
return signature, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EncodeToHex encodes bytes to hex string
|
// EncodeToHex encodes bytes to hex string
|
||||||
func EncodeToHex(data []byte) string {
|
func EncodeToHex(data []byte) string {
|
||||||
return hex.EncodeToString(data)
|
return hex.EncodeToString(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecodeFromHex decodes hex string to bytes
|
// DecodeFromHex decodes hex string to bytes
|
||||||
func DecodeFromHex(s string) ([]byte, error) {
|
func DecodeFromHex(s string) ([]byte, error) {
|
||||||
return hex.DecodeString(s)
|
return hex.DecodeString(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// EncodeToBase64 encodes bytes to base64 string
|
// EncodeToBase64 encodes bytes to base64 string
|
||||||
func EncodeToBase64(data []byte) string {
|
func EncodeToBase64(data []byte) string {
|
||||||
return hex.EncodeToString(data) // Using hex for simplicity, could use base64
|
return hex.EncodeToString(data) // Using hex for simplicity, could use base64
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecodeFromBase64 decodes base64 string to bytes
|
// DecodeFromBase64 decodes base64 string to bytes
|
||||||
func DecodeFromBase64(s string) ([]byte, error) {
|
func DecodeFromBase64(s string) ([]byte, error) {
|
||||||
return hex.DecodeString(s)
|
return hex.DecodeString(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalPublicKey marshals an ECDSA public key to bytes
|
// MarshalPublicKey marshals an ECDSA public key to bytes
|
||||||
func MarshalPublicKey(pubKey *ecdsa.PublicKey) []byte {
|
func MarshalPublicKey(pubKey *ecdsa.PublicKey) []byte {
|
||||||
return elliptic.Marshal(pubKey.Curve, pubKey.X, pubKey.Y)
|
return elliptic.Marshal(pubKey.Curve, pubKey.X, pubKey.Y)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CompareBytes performs constant-time comparison of two byte slices
|
// CompareBytes performs constant-time comparison of two byte slices
|
||||||
func CompareBytes(a, b []byte) bool {
|
func CompareBytes(a, b []byte) bool {
|
||||||
return SecureCompare(a, b)
|
return SecureCompare(a, b)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,141 +1,141 @@
|
||||||
package errors
|
package errors
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Domain errors
|
// Domain errors
|
||||||
var (
|
var (
|
||||||
// Session errors
|
// Session errors
|
||||||
ErrSessionNotFound = errors.New("session not found")
|
ErrSessionNotFound = errors.New("session not found")
|
||||||
ErrSessionExpired = errors.New("session expired")
|
ErrSessionExpired = errors.New("session expired")
|
||||||
ErrSessionAlreadyExists = errors.New("session already exists")
|
ErrSessionAlreadyExists = errors.New("session already exists")
|
||||||
ErrSessionFull = errors.New("session is full")
|
ErrSessionFull = errors.New("session is full")
|
||||||
ErrSessionNotInProgress = errors.New("session not in progress")
|
ErrSessionNotInProgress = errors.New("session not in progress")
|
||||||
ErrInvalidSessionType = errors.New("invalid session type")
|
ErrInvalidSessionType = errors.New("invalid session type")
|
||||||
ErrInvalidThreshold = errors.New("invalid threshold: t cannot exceed n")
|
ErrInvalidThreshold = errors.New("invalid threshold: t cannot exceed n")
|
||||||
|
|
||||||
// Participant errors
|
// Participant errors
|
||||||
ErrParticipantNotFound = errors.New("participant not found")
|
ErrParticipantNotFound = errors.New("participant not found")
|
||||||
ErrParticipantNotInvited = errors.New("participant not invited")
|
ErrParticipantNotInvited = errors.New("participant not invited")
|
||||||
ErrInvalidJoinToken = errors.New("invalid join token")
|
ErrInvalidJoinToken = errors.New("invalid join token")
|
||||||
ErrTokenMismatch = errors.New("token mismatch")
|
ErrTokenMismatch = errors.New("token mismatch")
|
||||||
ErrParticipantAlreadyJoined = errors.New("participant already joined")
|
ErrParticipantAlreadyJoined = errors.New("participant already joined")
|
||||||
|
|
||||||
// Message errors
|
// Message errors
|
||||||
ErrMessageNotFound = errors.New("message not found")
|
ErrMessageNotFound = errors.New("message not found")
|
||||||
ErrInvalidMessage = errors.New("invalid message")
|
ErrInvalidMessage = errors.New("invalid message")
|
||||||
ErrMessageDeliveryFailed = errors.New("message delivery failed")
|
ErrMessageDeliveryFailed = errors.New("message delivery failed")
|
||||||
|
|
||||||
// Key share errors
|
// Key share errors
|
||||||
ErrKeyShareNotFound = errors.New("key share not found")
|
ErrKeyShareNotFound = errors.New("key share not found")
|
||||||
ErrKeyShareCorrupted = errors.New("key share corrupted")
|
ErrKeyShareCorrupted = errors.New("key share corrupted")
|
||||||
ErrDecryptionFailed = errors.New("decryption failed")
|
ErrDecryptionFailed = errors.New("decryption failed")
|
||||||
|
|
||||||
// Account errors
|
// Account errors
|
||||||
ErrAccountNotFound = errors.New("account not found")
|
ErrAccountNotFound = errors.New("account not found")
|
||||||
ErrAccountExists = errors.New("account already exists")
|
ErrAccountExists = errors.New("account already exists")
|
||||||
ErrAccountSuspended = errors.New("account suspended")
|
ErrAccountSuspended = errors.New("account suspended")
|
||||||
ErrInvalidCredentials = errors.New("invalid credentials")
|
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||||
|
|
||||||
// Crypto errors
|
// Crypto errors
|
||||||
ErrInvalidPublicKey = errors.New("invalid public key")
|
ErrInvalidPublicKey = errors.New("invalid public key")
|
||||||
ErrInvalidSignature = errors.New("invalid signature")
|
ErrInvalidSignature = errors.New("invalid signature")
|
||||||
ErrSigningFailed = errors.New("signing failed")
|
ErrSigningFailed = errors.New("signing failed")
|
||||||
ErrKeygenFailed = errors.New("keygen failed")
|
ErrKeygenFailed = errors.New("keygen failed")
|
||||||
|
|
||||||
// Infrastructure errors
|
// Infrastructure errors
|
||||||
ErrDatabaseConnection = errors.New("database connection error")
|
ErrDatabaseConnection = errors.New("database connection error")
|
||||||
ErrCacheConnection = errors.New("cache connection error")
|
ErrCacheConnection = errors.New("cache connection error")
|
||||||
ErrQueueConnection = errors.New("queue connection error")
|
ErrQueueConnection = errors.New("queue connection error")
|
||||||
)
|
)
|
||||||
|
|
||||||
// DomainError represents a domain-specific error with additional context
|
// DomainError represents a domain-specific error with additional context
|
||||||
type DomainError struct {
|
type DomainError struct {
|
||||||
Err error
|
Err error
|
||||||
Message string
|
Message string
|
||||||
Code string
|
Code string
|
||||||
Details map[string]interface{}
|
Details map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DomainError) Error() string {
|
func (e *DomainError) Error() string {
|
||||||
if e.Message != "" {
|
if e.Message != "" {
|
||||||
return fmt.Sprintf("%s: %v", e.Message, e.Err)
|
return fmt.Sprintf("%s: %v", e.Message, e.Err)
|
||||||
}
|
}
|
||||||
return e.Err.Error()
|
return e.Err.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DomainError) Unwrap() error {
|
func (e *DomainError) Unwrap() error {
|
||||||
return e.Err
|
return e.Err
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDomainError creates a new domain error
|
// NewDomainError creates a new domain error
|
||||||
func NewDomainError(err error, code string, message string) *DomainError {
|
func NewDomainError(err error, code string, message string) *DomainError {
|
||||||
return &DomainError{
|
return &DomainError{
|
||||||
Err: err,
|
Err: err,
|
||||||
Code: code,
|
Code: code,
|
||||||
Message: message,
|
Message: message,
|
||||||
Details: make(map[string]interface{}),
|
Details: make(map[string]interface{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithDetail adds additional context to the error
|
// WithDetail adds additional context to the error
|
||||||
func (e *DomainError) WithDetail(key string, value interface{}) *DomainError {
|
func (e *DomainError) WithDetail(key string, value interface{}) *DomainError {
|
||||||
e.Details[key] = value
|
e.Details[key] = value
|
||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidationError represents input validation errors
|
// ValidationError represents input validation errors
|
||||||
type ValidationError struct {
|
type ValidationError struct {
|
||||||
Field string
|
Field string
|
||||||
Message string
|
Message string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ValidationError) Error() string {
|
func (e *ValidationError) Error() string {
|
||||||
return fmt.Sprintf("validation error on field '%s': %s", e.Field, e.Message)
|
return fmt.Sprintf("validation error on field '%s': %s", e.Field, e.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewValidationError creates a new validation error
|
// NewValidationError creates a new validation error
|
||||||
func NewValidationError(field, message string) *ValidationError {
|
func NewValidationError(field, message string) *ValidationError {
|
||||||
return &ValidationError{
|
return &ValidationError{
|
||||||
Field: field,
|
Field: field,
|
||||||
Message: message,
|
Message: message,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NotFoundError represents a resource not found error
|
// NotFoundError represents a resource not found error
|
||||||
type NotFoundError struct {
|
type NotFoundError struct {
|
||||||
Resource string
|
Resource string
|
||||||
ID string
|
ID string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *NotFoundError) Error() string {
|
func (e *NotFoundError) Error() string {
|
||||||
return fmt.Sprintf("%s with id '%s' not found", e.Resource, e.ID)
|
return fmt.Sprintf("%s with id '%s' not found", e.Resource, e.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNotFoundError creates a new not found error
|
// NewNotFoundError creates a new not found error
|
||||||
func NewNotFoundError(resource, id string) *NotFoundError {
|
func NewNotFoundError(resource, id string) *NotFoundError {
|
||||||
return &NotFoundError{
|
return &NotFoundError{
|
||||||
Resource: resource,
|
Resource: resource,
|
||||||
ID: id,
|
ID: id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Is checks if the target error matches
|
// Is checks if the target error matches
|
||||||
func Is(err, target error) bool {
|
func Is(err, target error) bool {
|
||||||
return errors.Is(err, target)
|
return errors.Is(err, target)
|
||||||
}
|
}
|
||||||
|
|
||||||
// As attempts to convert err to target type
|
// As attempts to convert err to target type
|
||||||
func As(err error, target interface{}) bool {
|
func As(err error, target interface{}) bool {
|
||||||
return errors.As(err, target)
|
return errors.As(err, target)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap wraps an error with additional context
|
// Wrap wraps an error with additional context
|
||||||
func Wrap(err error, message string) error {
|
func Wrap(err error, message string) error {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("%s: %w", message, err)
|
return fmt.Errorf("%s: %w", message, err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,234 +1,234 @@
|
||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidToken = errors.New("invalid token")
|
ErrInvalidToken = errors.New("invalid token")
|
||||||
ErrExpiredToken = errors.New("token expired")
|
ErrExpiredToken = errors.New("token expired")
|
||||||
ErrInvalidClaims = errors.New("invalid claims")
|
ErrInvalidClaims = errors.New("invalid claims")
|
||||||
ErrTokenNotYetValid = errors.New("token not yet valid")
|
ErrTokenNotYetValid = errors.New("token not yet valid")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Claims represents custom JWT claims
|
// Claims represents custom JWT claims
|
||||||
type Claims struct {
|
type Claims struct {
|
||||||
SessionID string `json:"session_id"`
|
SessionID string `json:"session_id"`
|
||||||
PartyID string `json:"party_id"`
|
PartyID string `json:"party_id"`
|
||||||
TokenType string `json:"token_type"` // "join", "access", "refresh"
|
TokenType string `json:"token_type"` // "join", "access", "refresh"
|
||||||
jwt.RegisteredClaims
|
jwt.RegisteredClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTService provides JWT operations
|
// JWTService provides JWT operations
|
||||||
type JWTService struct {
|
type JWTService struct {
|
||||||
secretKey []byte
|
secretKey []byte
|
||||||
issuer string
|
issuer string
|
||||||
tokenExpiry time.Duration
|
tokenExpiry time.Duration
|
||||||
refreshExpiry time.Duration
|
refreshExpiry time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewJWTService creates a new JWT service
|
// NewJWTService creates a new JWT service
|
||||||
func NewJWTService(secretKey string, issuer string, tokenExpiry, refreshExpiry time.Duration) *JWTService {
|
func NewJWTService(secretKey string, issuer string, tokenExpiry, refreshExpiry time.Duration) *JWTService {
|
||||||
return &JWTService{
|
return &JWTService{
|
||||||
secretKey: []byte(secretKey),
|
secretKey: []byte(secretKey),
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
tokenExpiry: tokenExpiry,
|
tokenExpiry: tokenExpiry,
|
||||||
refreshExpiry: refreshExpiry,
|
refreshExpiry: refreshExpiry,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateJoinToken generates a token for joining an MPC session
|
// GenerateJoinToken generates a token for joining an MPC session
|
||||||
func (s *JWTService) GenerateJoinToken(sessionID uuid.UUID, partyID string, expiresIn time.Duration) (string, error) {
|
func (s *JWTService) GenerateJoinToken(sessionID uuid.UUID, partyID string, expiresIn time.Duration) (string, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
claims := Claims{
|
claims := Claims{
|
||||||
SessionID: sessionID.String(),
|
SessionID: sessionID.String(),
|
||||||
PartyID: partyID,
|
PartyID: partyID,
|
||||||
TokenType: "join",
|
TokenType: "join",
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
ID: uuid.New().String(),
|
ID: uuid.New().String(),
|
||||||
Issuer: s.issuer,
|
Issuer: s.issuer,
|
||||||
Subject: partyID,
|
Subject: partyID,
|
||||||
IssuedAt: jwt.NewNumericDate(now),
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
ExpiresAt: jwt.NewNumericDate(now.Add(expiresIn)),
|
ExpiresAt: jwt.NewNumericDate(now.Add(expiresIn)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
return token.SignedString(s.secretKey)
|
return token.SignedString(s.secretKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccessTokenClaims represents claims in an access token
|
// AccessTokenClaims represents claims in an access token
|
||||||
type AccessTokenClaims struct {
|
type AccessTokenClaims struct {
|
||||||
Subject string
|
Subject string
|
||||||
Username string
|
Username string
|
||||||
Issuer string
|
Issuer string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateAccessToken generates an access token with username
|
// GenerateAccessToken generates an access token with username
|
||||||
func (s *JWTService) GenerateAccessToken(userID, username string) (string, error) {
|
func (s *JWTService) GenerateAccessToken(userID, username string) (string, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
claims := Claims{
|
claims := Claims{
|
||||||
TokenType: "access",
|
TokenType: "access",
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
ID: uuid.New().String(),
|
ID: uuid.New().String(),
|
||||||
Issuer: s.issuer,
|
Issuer: s.issuer,
|
||||||
Subject: userID,
|
Subject: userID,
|
||||||
IssuedAt: jwt.NewNumericDate(now),
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
ExpiresAt: jwt.NewNumericDate(now.Add(s.tokenExpiry)),
|
ExpiresAt: jwt.NewNumericDate(now.Add(s.tokenExpiry)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// Store username in PartyID field for access tokens
|
// Store username in PartyID field for access tokens
|
||||||
claims.PartyID = username
|
claims.PartyID = username
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
return token.SignedString(s.secretKey)
|
return token.SignedString(s.secretKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateRefreshToken generates a refresh token
|
// GenerateRefreshToken generates a refresh token
|
||||||
func (s *JWTService) GenerateRefreshToken(userID string) (string, error) {
|
func (s *JWTService) GenerateRefreshToken(userID string) (string, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
claims := Claims{
|
claims := Claims{
|
||||||
TokenType: "refresh",
|
TokenType: "refresh",
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
ID: uuid.New().String(),
|
ID: uuid.New().String(),
|
||||||
Issuer: s.issuer,
|
Issuer: s.issuer,
|
||||||
Subject: userID,
|
Subject: userID,
|
||||||
IssuedAt: jwt.NewNumericDate(now),
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
NotBefore: jwt.NewNumericDate(now),
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
ExpiresAt: jwt.NewNumericDate(now.Add(s.refreshExpiry)),
|
ExpiresAt: jwt.NewNumericDate(now.Add(s.refreshExpiry)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
return token.SignedString(s.secretKey)
|
return token.SignedString(s.secretKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateToken validates a JWT token and returns the claims
|
// ValidateToken validates a JWT token and returns the claims
|
||||||
func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
|
func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
|
||||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
return nil, ErrInvalidToken
|
return nil, ErrInvalidToken
|
||||||
}
|
}
|
||||||
return s.secretKey, nil
|
return s.secretKey, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||||
return nil, ErrExpiredToken
|
return nil, ErrExpiredToken
|
||||||
}
|
}
|
||||||
return nil, ErrInvalidToken
|
return nil, ErrInvalidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, ok := token.Claims.(*Claims)
|
claims, ok := token.Claims.(*Claims)
|
||||||
if !ok || !token.Valid {
|
if !ok || !token.Valid {
|
||||||
return nil, ErrInvalidClaims
|
return nil, ErrInvalidClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseJoinTokenClaims parses a join token and extracts claims without validating session ID
|
// ParseJoinTokenClaims parses a join token and extracts claims without validating session ID
|
||||||
// This is used when the session ID is not known beforehand (e.g., join by token)
|
// This is used when the session ID is not known beforehand (e.g., join by token)
|
||||||
func (s *JWTService) ParseJoinTokenClaims(tokenString string) (*Claims, error) {
|
func (s *JWTService) ParseJoinTokenClaims(tokenString string) (*Claims, error) {
|
||||||
claims, err := s.ValidateToken(tokenString)
|
claims, err := s.ValidateToken(tokenString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims.TokenType != "join" {
|
if claims.TokenType != "join" {
|
||||||
return nil, ErrInvalidToken
|
return nil, ErrInvalidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateJoinToken validates a join token for MPC sessions
|
// ValidateJoinToken validates a join token for MPC sessions
|
||||||
func (s *JWTService) ValidateJoinToken(tokenString string, sessionID uuid.UUID, partyID string) (*Claims, error) {
|
func (s *JWTService) ValidateJoinToken(tokenString string, sessionID uuid.UUID, partyID string) (*Claims, error) {
|
||||||
claims, err := s.ValidateToken(tokenString)
|
claims, err := s.ValidateToken(tokenString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims.TokenType != "join" {
|
if claims.TokenType != "join" {
|
||||||
return nil, ErrInvalidToken
|
return nil, ErrInvalidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims.SessionID != sessionID.String() {
|
if claims.SessionID != sessionID.String() {
|
||||||
return nil, ErrInvalidClaims
|
return nil, ErrInvalidClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allow wildcard party ID "*" for dynamic joining, otherwise must match exactly
|
// Allow wildcard party ID "*" for dynamic joining, otherwise must match exactly
|
||||||
if claims.PartyID != "*" && claims.PartyID != partyID {
|
if claims.PartyID != "*" && claims.PartyID != partyID {
|
||||||
return nil, ErrInvalidClaims
|
return nil, ErrInvalidClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshAccessToken creates a new access token from a valid refresh token
|
// RefreshAccessToken creates a new access token from a valid refresh token
|
||||||
func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) {
|
func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) {
|
||||||
claims, err := s.ValidateToken(refreshToken)
|
claims, err := s.ValidateToken(refreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims.TokenType != "refresh" {
|
if claims.TokenType != "refresh" {
|
||||||
return "", ErrInvalidToken
|
return "", ErrInvalidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
// PartyID stores the username for access tokens
|
// PartyID stores the username for access tokens
|
||||||
return s.GenerateAccessToken(claims.Subject, claims.PartyID)
|
return s.GenerateAccessToken(claims.Subject, claims.PartyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateAccessToken validates an access token and returns structured claims
|
// ValidateAccessToken validates an access token and returns structured claims
|
||||||
func (s *JWTService) ValidateAccessToken(tokenString string) (*AccessTokenClaims, error) {
|
func (s *JWTService) ValidateAccessToken(tokenString string) (*AccessTokenClaims, error) {
|
||||||
claims, err := s.ValidateToken(tokenString)
|
claims, err := s.ValidateToken(tokenString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims.TokenType != "access" {
|
if claims.TokenType != "access" {
|
||||||
return nil, ErrInvalidToken
|
return nil, ErrInvalidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
return &AccessTokenClaims{
|
return &AccessTokenClaims{
|
||||||
Subject: claims.Subject,
|
Subject: claims.Subject,
|
||||||
Username: claims.PartyID, // Username stored in PartyID for access tokens
|
Username: claims.PartyID, // Username stored in PartyID for access tokens
|
||||||
Issuer: claims.Issuer,
|
Issuer: claims.Issuer,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateRefreshToken validates a refresh token and returns claims
|
// ValidateRefreshToken validates a refresh token and returns claims
|
||||||
func (s *JWTService) ValidateRefreshToken(tokenString string) (*Claims, error) {
|
func (s *JWTService) ValidateRefreshToken(tokenString string) (*Claims, error) {
|
||||||
claims, err := s.ValidateToken(tokenString)
|
claims, err := s.ValidateToken(tokenString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims.TokenType != "refresh" {
|
if claims.TokenType != "refresh" {
|
||||||
return nil, ErrInvalidToken
|
return nil, ErrInvalidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenGenerator interface for dependency injection
|
// TokenGenerator interface for dependency injection
|
||||||
type TokenGenerator interface {
|
type TokenGenerator interface {
|
||||||
GenerateJoinToken(sessionID uuid.UUID, partyID string, expiresIn time.Duration) (string, error)
|
GenerateJoinToken(sessionID uuid.UUID, partyID string, expiresIn time.Duration) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenValidator interface for dependency injection
|
// TokenValidator interface for dependency injection
|
||||||
type TokenValidator interface {
|
type TokenValidator interface {
|
||||||
ParseJoinTokenClaims(tokenString string) (*Claims, error)
|
ParseJoinTokenClaims(tokenString string) (*Claims, error)
|
||||||
ValidateJoinToken(tokenString string, sessionID uuid.UUID, partyID string) (*Claims, error)
|
ValidateJoinToken(tokenString string, sessionID uuid.UUID, partyID string) (*Claims, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure JWTService implements interfaces
|
// Ensure JWTService implements interfaces
|
||||||
var _ TokenGenerator = (*JWTService)(nil)
|
var _ TokenGenerator = (*JWTService)(nil)
|
||||||
var _ TokenValidator = (*JWTService)(nil)
|
var _ TokenValidator = (*JWTService)(nil)
|
||||||
|
|
|
||||||
|
|
@ -1,169 +1,169 @@
|
||||||
package logger
|
package logger
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"go.uber.org/zap/zapcore"
|
"go.uber.org/zap/zapcore"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
Log *zap.Logger
|
Log *zap.Logger
|
||||||
Sugar *zap.SugaredLogger
|
Sugar *zap.SugaredLogger
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config holds logger configuration
|
// Config holds logger configuration
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Level string `mapstructure:"level"`
|
Level string `mapstructure:"level"`
|
||||||
Encoding string `mapstructure:"encoding"`
|
Encoding string `mapstructure:"encoding"`
|
||||||
OutputPath string `mapstructure:"output_path"`
|
OutputPath string `mapstructure:"output_path"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init initializes the global logger
|
// Init initializes the global logger
|
||||||
func Init(cfg *Config) error {
|
func Init(cfg *Config) error {
|
||||||
level := zapcore.InfoLevel
|
level := zapcore.InfoLevel
|
||||||
if cfg != nil && cfg.Level != "" {
|
if cfg != nil && cfg.Level != "" {
|
||||||
if err := level.UnmarshalText([]byte(cfg.Level)); err != nil {
|
if err := level.UnmarshalText([]byte(cfg.Level)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
encoding := "json"
|
encoding := "json"
|
||||||
if cfg != nil && cfg.Encoding != "" {
|
if cfg != nil && cfg.Encoding != "" {
|
||||||
encoding = cfg.Encoding
|
encoding = cfg.Encoding
|
||||||
}
|
}
|
||||||
|
|
||||||
outputPath := "stdout"
|
outputPath := "stdout"
|
||||||
if cfg != nil && cfg.OutputPath != "" {
|
if cfg != nil && cfg.OutputPath != "" {
|
||||||
outputPath = cfg.OutputPath
|
outputPath = cfg.OutputPath
|
||||||
}
|
}
|
||||||
|
|
||||||
zapConfig := zap.Config{
|
zapConfig := zap.Config{
|
||||||
Level: zap.NewAtomicLevelAt(level),
|
Level: zap.NewAtomicLevelAt(level),
|
||||||
Development: false,
|
Development: false,
|
||||||
DisableCaller: false,
|
DisableCaller: false,
|
||||||
DisableStacktrace: false,
|
DisableStacktrace: false,
|
||||||
Sampling: nil,
|
Sampling: nil,
|
||||||
Encoding: encoding,
|
Encoding: encoding,
|
||||||
EncoderConfig: zapcore.EncoderConfig{
|
EncoderConfig: zapcore.EncoderConfig{
|
||||||
MessageKey: "message",
|
MessageKey: "message",
|
||||||
LevelKey: "level",
|
LevelKey: "level",
|
||||||
TimeKey: "time",
|
TimeKey: "time",
|
||||||
NameKey: "logger",
|
NameKey: "logger",
|
||||||
CallerKey: "caller",
|
CallerKey: "caller",
|
||||||
FunctionKey: zapcore.OmitKey,
|
FunctionKey: zapcore.OmitKey,
|
||||||
StacktraceKey: "stacktrace",
|
StacktraceKey: "stacktrace",
|
||||||
LineEnding: zapcore.DefaultLineEnding,
|
LineEnding: zapcore.DefaultLineEnding,
|
||||||
EncodeLevel: zapcore.LowercaseLevelEncoder,
|
EncodeLevel: zapcore.LowercaseLevelEncoder,
|
||||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||||
EncodeDuration: zapcore.SecondsDurationEncoder,
|
EncodeDuration: zapcore.SecondsDurationEncoder,
|
||||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||||
},
|
},
|
||||||
OutputPaths: []string{outputPath},
|
OutputPaths: []string{outputPath},
|
||||||
ErrorOutputPaths: []string{"stderr"},
|
ErrorOutputPaths: []string{"stderr"},
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
Log, err = zapConfig.Build()
|
Log, err = zapConfig.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
Sugar = Log.Sugar()
|
Sugar = Log.Sugar()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitDevelopment initializes logger for development environment
|
// InitDevelopment initializes logger for development environment
|
||||||
func InitDevelopment() error {
|
func InitDevelopment() error {
|
||||||
var err error
|
var err error
|
||||||
Log, err = zap.NewDevelopment()
|
Log, err = zap.NewDevelopment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
Sugar = Log.Sugar()
|
Sugar = Log.Sugar()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitProduction initializes logger for production environment
|
// InitProduction initializes logger for production environment
|
||||||
func InitProduction() error {
|
func InitProduction() error {
|
||||||
var err error
|
var err error
|
||||||
Log, err = zap.NewProduction()
|
Log, err = zap.NewProduction()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
Sugar = Log.Sugar()
|
Sugar = Log.Sugar()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sync flushes any buffered log entries
|
// Sync flushes any buffered log entries
|
||||||
func Sync() error {
|
func Sync() error {
|
||||||
if Log != nil {
|
if Log != nil {
|
||||||
return Log.Sync()
|
return Log.Sync()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithFields creates a new logger with additional fields
|
// WithFields creates a new logger with additional fields
|
||||||
func WithFields(fields ...zap.Field) *zap.Logger {
|
func WithFields(fields ...zap.Field) *zap.Logger {
|
||||||
return Log.With(fields...)
|
return Log.With(fields...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Debug logs a debug message
|
// Debug logs a debug message
|
||||||
func Debug(msg string, fields ...zap.Field) {
|
func Debug(msg string, fields ...zap.Field) {
|
||||||
Log.Debug(msg, fields...)
|
Log.Debug(msg, fields...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Info logs an info message
|
// Info logs an info message
|
||||||
func Info(msg string, fields ...zap.Field) {
|
func Info(msg string, fields ...zap.Field) {
|
||||||
Log.Info(msg, fields...)
|
Log.Info(msg, fields...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Warn logs a warning message
|
// Warn logs a warning message
|
||||||
func Warn(msg string, fields ...zap.Field) {
|
func Warn(msg string, fields ...zap.Field) {
|
||||||
Log.Warn(msg, fields...)
|
Log.Warn(msg, fields...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error logs an error message
|
// Error logs an error message
|
||||||
func Error(msg string, fields ...zap.Field) {
|
func Error(msg string, fields ...zap.Field) {
|
||||||
Log.Error(msg, fields...)
|
Log.Error(msg, fields...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fatal logs a fatal message and exits
|
// Fatal logs a fatal message and exits
|
||||||
func Fatal(msg string, fields ...zap.Field) {
|
func Fatal(msg string, fields ...zap.Field) {
|
||||||
Log.Fatal(msg, fields...)
|
Log.Fatal(msg, fields...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Panic logs a panic message and panics
|
// Panic logs a panic message and panics
|
||||||
func Panic(msg string, fields ...zap.Field) {
|
func Panic(msg string, fields ...zap.Field) {
|
||||||
Log.Panic(msg, fields...)
|
Log.Panic(msg, fields...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Field creates a zap field
|
// Field creates a zap field
|
||||||
func Field(key string, value interface{}) zap.Field {
|
func Field(key string, value interface{}) zap.Field {
|
||||||
return zap.Any(key, value)
|
return zap.Any(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// String creates a string field
|
// String creates a string field
|
||||||
func String(key, value string) zap.Field {
|
func String(key, value string) zap.Field {
|
||||||
return zap.String(key, value)
|
return zap.String(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Int creates an int field
|
// Int creates an int field
|
||||||
func Int(key string, value int) zap.Field {
|
func Int(key string, value int) zap.Field {
|
||||||
return zap.Int(key, value)
|
return zap.Int(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Err creates an error field
|
// Err creates an error field
|
||||||
func Err(err error) zap.Field {
|
func Err(err error) zap.Field {
|
||||||
return zap.Error(err)
|
return zap.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// Initialize with development logger by default
|
// Initialize with development logger by default
|
||||||
// This will be overridden when Init() is called with proper config
|
// This will be overridden when Init() is called with proper config
|
||||||
if os.Getenv("ENV") == "production" {
|
if os.Getenv("ENV") == "production" {
|
||||||
_ = InitProduction()
|
_ = InitProduction()
|
||||||
} else {
|
} else {
|
||||||
_ = InitDevelopment()
|
_ = InitDevelopment()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,405 +1,405 @@
|
||||||
package tss
|
package tss
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
|
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
|
||||||
"github.com/bnb-chain/tss-lib/v2/tss"
|
"github.com/bnb-chain/tss-lib/v2/tss"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrKeygenTimeout = errors.New("keygen timeout")
|
ErrKeygenTimeout = errors.New("keygen timeout")
|
||||||
ErrKeygenFailed = errors.New("keygen failed")
|
ErrKeygenFailed = errors.New("keygen failed")
|
||||||
ErrInvalidPartyCount = errors.New("invalid party count")
|
ErrInvalidPartyCount = errors.New("invalid party count")
|
||||||
ErrInvalidThreshold = errors.New("invalid threshold")
|
ErrInvalidThreshold = errors.New("invalid threshold")
|
||||||
)
|
)
|
||||||
|
|
||||||
// KeygenResult contains the result of a keygen operation
|
// KeygenResult contains the result of a keygen operation
|
||||||
type KeygenResult struct {
|
type KeygenResult struct {
|
||||||
// LocalPartySaveData is the serialized save data for this party
|
// LocalPartySaveData is the serialized save data for this party
|
||||||
LocalPartySaveData []byte
|
LocalPartySaveData []byte
|
||||||
// PublicKey is the group ECDSA public key
|
// PublicKey is the group ECDSA public key
|
||||||
PublicKey *ecdsa.PublicKey
|
PublicKey *ecdsa.PublicKey
|
||||||
// PublicKeyBytes is the compressed public key bytes
|
// PublicKeyBytes is the compressed public key bytes
|
||||||
PublicKeyBytes []byte
|
PublicKeyBytes []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// KeygenParty represents a party participating in keygen
|
// KeygenParty represents a party participating in keygen
|
||||||
type KeygenParty struct {
|
type KeygenParty struct {
|
||||||
PartyID string
|
PartyID string
|
||||||
PartyIndex int
|
PartyIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
// KeygenConfig contains configuration for keygen
|
// KeygenConfig contains configuration for keygen
|
||||||
type KeygenConfig struct {
|
type KeygenConfig struct {
|
||||||
Threshold int // t in t-of-n
|
Threshold int // t in t-of-n
|
||||||
TotalParties int // n in t-of-n
|
TotalParties int // n in t-of-n
|
||||||
Timeout time.Duration // Keygen timeout
|
Timeout time.Duration // Keygen timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
// KeygenSession manages a keygen session for a single party
|
// KeygenSession manages a keygen session for a single party
|
||||||
type KeygenSession struct {
|
type KeygenSession struct {
|
||||||
config KeygenConfig
|
config KeygenConfig
|
||||||
selfParty KeygenParty
|
selfParty KeygenParty
|
||||||
allParties []KeygenParty
|
allParties []KeygenParty
|
||||||
tssPartyIDs []*tss.PartyID
|
tssPartyIDs []*tss.PartyID
|
||||||
selfTSSID *tss.PartyID
|
selfTSSID *tss.PartyID
|
||||||
params *tss.Parameters
|
params *tss.Parameters
|
||||||
localParty tss.Party
|
localParty tss.Party
|
||||||
outCh chan tss.Message
|
outCh chan tss.Message
|
||||||
endCh chan *keygen.LocalPartySaveData
|
endCh chan *keygen.LocalPartySaveData
|
||||||
errCh chan error
|
errCh chan error
|
||||||
msgHandler MessageHandler
|
msgHandler MessageHandler
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
started bool
|
started bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// MessageHandler handles outgoing and incoming TSS messages
|
// MessageHandler handles outgoing and incoming TSS messages
|
||||||
type MessageHandler interface {
|
type MessageHandler interface {
|
||||||
// SendMessage sends a message to other parties
|
// SendMessage sends a message to other parties
|
||||||
SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error
|
SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error
|
||||||
// ReceiveMessages returns a channel for receiving messages
|
// ReceiveMessages returns a channel for receiving messages
|
||||||
ReceiveMessages() <-chan *ReceivedMessage
|
ReceiveMessages() <-chan *ReceivedMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReceivedMessage represents a received TSS message
|
// ReceivedMessage represents a received TSS message
|
||||||
type ReceivedMessage struct {
|
type ReceivedMessage struct {
|
||||||
FromPartyIndex int
|
FromPartyIndex int
|
||||||
IsBroadcast bool
|
IsBroadcast bool
|
||||||
MsgBytes []byte
|
MsgBytes []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewKeygenSession creates a new keygen session
|
// NewKeygenSession creates a new keygen session
|
||||||
func NewKeygenSession(
|
func NewKeygenSession(
|
||||||
config KeygenConfig,
|
config KeygenConfig,
|
||||||
selfParty KeygenParty,
|
selfParty KeygenParty,
|
||||||
allParties []KeygenParty,
|
allParties []KeygenParty,
|
||||||
msgHandler MessageHandler,
|
msgHandler MessageHandler,
|
||||||
) (*KeygenSession, error) {
|
) (*KeygenSession, error) {
|
||||||
if config.TotalParties < 2 {
|
if config.TotalParties < 2 {
|
||||||
return nil, ErrInvalidPartyCount
|
return nil, ErrInvalidPartyCount
|
||||||
}
|
}
|
||||||
if config.Threshold < 1 || config.Threshold > config.TotalParties {
|
if config.Threshold < 1 || config.Threshold > config.TotalParties {
|
||||||
return nil, ErrInvalidThreshold
|
return nil, ErrInvalidThreshold
|
||||||
}
|
}
|
||||||
if len(allParties) != config.TotalParties {
|
if len(allParties) != config.TotalParties {
|
||||||
return nil, ErrInvalidPartyCount
|
return nil, ErrInvalidPartyCount
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create TSS party IDs
|
// Create TSS party IDs
|
||||||
tssPartyIDs := make([]*tss.PartyID, len(allParties))
|
tssPartyIDs := make([]*tss.PartyID, len(allParties))
|
||||||
var selfTSSID *tss.PartyID
|
var selfTSSID *tss.PartyID
|
||||||
for i, p := range allParties {
|
for i, p := range allParties {
|
||||||
partyID := tss.NewPartyID(
|
partyID := tss.NewPartyID(
|
||||||
p.PartyID,
|
p.PartyID,
|
||||||
fmt.Sprintf("party-%d", p.PartyIndex),
|
fmt.Sprintf("party-%d", p.PartyIndex),
|
||||||
big.NewInt(int64(p.PartyIndex+1)),
|
big.NewInt(int64(p.PartyIndex+1)),
|
||||||
)
|
)
|
||||||
tssPartyIDs[i] = partyID
|
tssPartyIDs[i] = partyID
|
||||||
if p.PartyID == selfParty.PartyID {
|
if p.PartyID == selfParty.PartyID {
|
||||||
selfTSSID = partyID
|
selfTSSID = partyID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if selfTSSID == nil {
|
if selfTSSID == nil {
|
||||||
return nil, errors.New("self party not found in all parties")
|
return nil, errors.New("self party not found in all parties")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort party IDs
|
// Sort party IDs
|
||||||
sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
|
sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
|
||||||
|
|
||||||
// Create peer context and parameters
|
// Create peer context and parameters
|
||||||
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
||||||
params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), config.Threshold)
|
params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), config.Threshold)
|
||||||
|
|
||||||
return &KeygenSession{
|
return &KeygenSession{
|
||||||
config: config,
|
config: config,
|
||||||
selfParty: selfParty,
|
selfParty: selfParty,
|
||||||
allParties: allParties,
|
allParties: allParties,
|
||||||
tssPartyIDs: sortedPartyIDs,
|
tssPartyIDs: sortedPartyIDs,
|
||||||
selfTSSID: selfTSSID,
|
selfTSSID: selfTSSID,
|
||||||
params: params,
|
params: params,
|
||||||
outCh: make(chan tss.Message, config.TotalParties*10),
|
outCh: make(chan tss.Message, config.TotalParties*10),
|
||||||
endCh: make(chan *keygen.LocalPartySaveData, 1),
|
endCh: make(chan *keygen.LocalPartySaveData, 1),
|
||||||
errCh: make(chan error, 1),
|
errCh: make(chan error, 1),
|
||||||
msgHandler: msgHandler,
|
msgHandler: msgHandler,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start begins the keygen protocol
|
// Start begins the keygen protocol
|
||||||
func (s *KeygenSession) Start(ctx context.Context) (*KeygenResult, error) {
|
func (s *KeygenSession) Start(ctx context.Context) (*KeygenResult, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if s.started {
|
if s.started {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return nil, errors.New("session already started")
|
return nil, errors.New("session already started")
|
||||||
}
|
}
|
||||||
s.started = true
|
s.started = true
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
// Create local party
|
// Create local party
|
||||||
s.localParty = keygen.NewLocalParty(s.params, s.outCh, s.endCh)
|
s.localParty = keygen.NewLocalParty(s.params, s.outCh, s.endCh)
|
||||||
|
|
||||||
// Start the local party
|
// Start the local party
|
||||||
go func() {
|
go func() {
|
||||||
if err := s.localParty.Start(); err != nil {
|
if err := s.localParty.Start(); err != nil {
|
||||||
s.errCh <- err
|
s.errCh <- err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Handle outgoing messages
|
// Handle outgoing messages
|
||||||
go s.handleOutgoingMessages(ctx)
|
go s.handleOutgoingMessages(ctx)
|
||||||
|
|
||||||
// Handle incoming messages
|
// Handle incoming messages
|
||||||
go s.handleIncomingMessages(ctx)
|
go s.handleIncomingMessages(ctx)
|
||||||
|
|
||||||
// Wait for completion or timeout
|
// Wait for completion or timeout
|
||||||
timeout := s.config.Timeout
|
timeout := s.config.Timeout
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = 10 * time.Minute
|
timeout = 10 * time.Minute
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
case <-time.After(timeout):
|
case <-time.After(timeout):
|
||||||
return nil, ErrKeygenTimeout
|
return nil, ErrKeygenTimeout
|
||||||
case tssErr := <-s.errCh:
|
case tssErr := <-s.errCh:
|
||||||
return nil, fmt.Errorf("%w: %v", ErrKeygenFailed, tssErr)
|
return nil, fmt.Errorf("%w: %v", ErrKeygenFailed, tssErr)
|
||||||
case saveData := <-s.endCh:
|
case saveData := <-s.endCh:
|
||||||
return s.buildResult(saveData)
|
return s.buildResult(saveData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *KeygenSession) handleOutgoingMessages(ctx context.Context) {
|
func (s *KeygenSession) handleOutgoingMessages(ctx context.Context) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case msg := <-s.outCh:
|
case msg := <-s.outCh:
|
||||||
if msg == nil {
|
if msg == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
msgBytes, _, err := msg.WireBytes()
|
msgBytes, _, err := msg.WireBytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var toParties []string
|
var toParties []string
|
||||||
isBroadcast := msg.IsBroadcast()
|
isBroadcast := msg.IsBroadcast()
|
||||||
if !isBroadcast {
|
if !isBroadcast {
|
||||||
for _, to := range msg.GetTo() {
|
for _, to := range msg.GetTo() {
|
||||||
toParties = append(toParties, to.Id)
|
toParties = append(toParties, to.Id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.msgHandler.SendMessage(ctx, isBroadcast, toParties, msgBytes); err != nil {
|
if err := s.msgHandler.SendMessage(ctx, isBroadcast, toParties, msgBytes); err != nil {
|
||||||
// Log error but continue
|
// Log error but continue
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *KeygenSession) handleIncomingMessages(ctx context.Context) {
|
func (s *KeygenSession) handleIncomingMessages(ctx context.Context) {
|
||||||
msgCh := s.msgHandler.ReceiveMessages()
|
msgCh := s.msgHandler.ReceiveMessages()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case msg, ok := <-msgCh:
|
case msg, ok := <-msgCh:
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the message
|
// Parse the message
|
||||||
parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, s.tssPartyIDs[msg.FromPartyIndex], msg.IsBroadcast)
|
parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, s.tssPartyIDs[msg.FromPartyIndex], msg.IsBroadcast)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the party
|
// Update the party
|
||||||
go func() {
|
go func() {
|
||||||
ok, err := s.localParty.Update(parsedMsg)
|
ok, err := s.localParty.Update(parsedMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.errCh <- err
|
s.errCh <- err
|
||||||
}
|
}
|
||||||
_ = ok
|
_ = ok
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *KeygenSession) buildResult(saveData *keygen.LocalPartySaveData) (*KeygenResult, error) {
|
func (s *KeygenSession) buildResult(saveData *keygen.LocalPartySaveData) (*KeygenResult, error) {
|
||||||
// Serialize save data
|
// Serialize save data
|
||||||
saveDataBytes, err := json.Marshal(saveData)
|
saveDataBytes, err := json.Marshal(saveData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to serialize save data: %w", err)
|
return nil, fmt.Errorf("failed to serialize save data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get public key
|
// Get public key
|
||||||
pubKey := saveData.ECDSAPub.ToECDSAPubKey()
|
pubKey := saveData.ECDSAPub.ToECDSAPubKey()
|
||||||
|
|
||||||
// Compress public key
|
// Compress public key
|
||||||
pubKeyBytes := make([]byte, 33)
|
pubKeyBytes := make([]byte, 33)
|
||||||
pubKeyBytes[0] = 0x02 + byte(pubKey.Y.Bit(0))
|
pubKeyBytes[0] = 0x02 + byte(pubKey.Y.Bit(0))
|
||||||
xBytes := pubKey.X.Bytes()
|
xBytes := pubKey.X.Bytes()
|
||||||
copy(pubKeyBytes[33-len(xBytes):], xBytes)
|
copy(pubKeyBytes[33-len(xBytes):], xBytes)
|
||||||
|
|
||||||
return &KeygenResult{
|
return &KeygenResult{
|
||||||
LocalPartySaveData: saveDataBytes,
|
LocalPartySaveData: saveDataBytes,
|
||||||
PublicKey: pubKey,
|
PublicKey: pubKey,
|
||||||
PublicKeyBytes: pubKeyBytes,
|
PublicKeyBytes: pubKeyBytes,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalKeygenResult contains local keygen result for standalone testing
|
// LocalKeygenResult contains local keygen result for standalone testing
|
||||||
type LocalKeygenResult struct {
|
type LocalKeygenResult struct {
|
||||||
SaveData *keygen.LocalPartySaveData
|
SaveData *keygen.LocalPartySaveData
|
||||||
PublicKey *ecdsa.PublicKey
|
PublicKey *ecdsa.PublicKey
|
||||||
PartyIndex int
|
PartyIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunLocalKeygen runs keygen locally with all parties in the same process (for testing)
|
// RunLocalKeygen runs keygen locally with all parties in the same process (for testing)
|
||||||
func RunLocalKeygen(threshold, totalParties int) ([]*LocalKeygenResult, error) {
|
func RunLocalKeygen(threshold, totalParties int) ([]*LocalKeygenResult, error) {
|
||||||
if totalParties < 2 {
|
if totalParties < 2 {
|
||||||
return nil, ErrInvalidPartyCount
|
return nil, ErrInvalidPartyCount
|
||||||
}
|
}
|
||||||
if threshold < 1 || threshold > totalParties {
|
if threshold < 1 || threshold > totalParties {
|
||||||
return nil, ErrInvalidThreshold
|
return nil, ErrInvalidThreshold
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create party IDs
|
// Create party IDs
|
||||||
partyIDs := make([]*tss.PartyID, totalParties)
|
partyIDs := make([]*tss.PartyID, totalParties)
|
||||||
for i := 0; i < totalParties; i++ {
|
for i := 0; i < totalParties; i++ {
|
||||||
partyIDs[i] = tss.NewPartyID(
|
partyIDs[i] = tss.NewPartyID(
|
||||||
fmt.Sprintf("party-%d", i),
|
fmt.Sprintf("party-%d", i),
|
||||||
fmt.Sprintf("party-%d", i),
|
fmt.Sprintf("party-%d", i),
|
||||||
big.NewInt(int64(i+1)),
|
big.NewInt(int64(i+1)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
sortedPartyIDs := tss.SortPartyIDs(partyIDs)
|
sortedPartyIDs := tss.SortPartyIDs(partyIDs)
|
||||||
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
||||||
|
|
||||||
// Create channels for each party
|
// Create channels for each party
|
||||||
outChs := make([]chan tss.Message, totalParties)
|
outChs := make([]chan tss.Message, totalParties)
|
||||||
endChs := make([]chan *keygen.LocalPartySaveData, totalParties)
|
endChs := make([]chan *keygen.LocalPartySaveData, totalParties)
|
||||||
parties := make([]tss.Party, totalParties)
|
parties := make([]tss.Party, totalParties)
|
||||||
|
|
||||||
for i := 0; i < totalParties; i++ {
|
for i := 0; i < totalParties; i++ {
|
||||||
outChs[i] = make(chan tss.Message, totalParties*10)
|
outChs[i] = make(chan tss.Message, totalParties*10)
|
||||||
endChs[i] = make(chan *keygen.LocalPartySaveData, 1)
|
endChs[i] = make(chan *keygen.LocalPartySaveData, 1)
|
||||||
params := tss.NewParameters(tss.S256(), peerCtx, sortedPartyIDs[i], totalParties, threshold)
|
params := tss.NewParameters(tss.S256(), peerCtx, sortedPartyIDs[i], totalParties, threshold)
|
||||||
parties[i] = keygen.NewLocalParty(params, outChs[i], endChs[i])
|
parties[i] = keygen.NewLocalParty(params, outChs[i], endChs[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start all parties
|
// Start all parties
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
errCh := make(chan error, totalParties)
|
errCh := make(chan error, totalParties)
|
||||||
|
|
||||||
for i := 0; i < totalParties; i++ {
|
for i := 0; i < totalParties; i++ {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(idx int) {
|
go func(idx int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
if err := parties[idx].Start(); err != nil {
|
if err := parties[idx].Start(); err != nil {
|
||||||
errCh <- err
|
errCh <- err
|
||||||
}
|
}
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route messages between parties
|
// Route messages between parties
|
||||||
var routeWg sync.WaitGroup
|
var routeWg sync.WaitGroup
|
||||||
doneCh := make(chan struct{})
|
doneCh := make(chan struct{})
|
||||||
|
|
||||||
for i := 0; i < totalParties; i++ {
|
for i := 0; i < totalParties; i++ {
|
||||||
routeWg.Add(1)
|
routeWg.Add(1)
|
||||||
go func(idx int) {
|
go func(idx int) {
|
||||||
defer routeWg.Done()
|
defer routeWg.Done()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-doneCh:
|
case <-doneCh:
|
||||||
return
|
return
|
||||||
case msg := <-outChs[idx]:
|
case msg := <-outChs[idx]:
|
||||||
if msg == nil {
|
if msg == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dest := msg.GetTo()
|
dest := msg.GetTo()
|
||||||
if msg.IsBroadcast() {
|
if msg.IsBroadcast() {
|
||||||
for j := 0; j < totalParties; j++ {
|
for j := 0; j < totalParties; j++ {
|
||||||
if j != idx {
|
if j != idx {
|
||||||
go updateParty(parties[j], msg, errCh)
|
go updateParty(parties[j], msg, errCh)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for _, d := range dest {
|
for _, d := range dest {
|
||||||
for j := 0; j < totalParties; j++ {
|
for j := 0; j < totalParties; j++ {
|
||||||
if sortedPartyIDs[j].Id == d.Id {
|
if sortedPartyIDs[j].Id == d.Id {
|
||||||
go updateParty(parties[j], msg, errCh)
|
go updateParty(parties[j], msg, errCh)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collect results
|
// Collect results
|
||||||
results := make([]*LocalKeygenResult, totalParties)
|
results := make([]*LocalKeygenResult, totalParties)
|
||||||
for i := 0; i < totalParties; i++ {
|
for i := 0; i < totalParties; i++ {
|
||||||
select {
|
select {
|
||||||
case saveData := <-endChs[i]:
|
case saveData := <-endChs[i]:
|
||||||
results[i] = &LocalKeygenResult{
|
results[i] = &LocalKeygenResult{
|
||||||
SaveData: saveData,
|
SaveData: saveData,
|
||||||
PublicKey: saveData.ECDSAPub.ToECDSAPubKey(),
|
PublicKey: saveData.ECDSAPub.ToECDSAPubKey(),
|
||||||
PartyIndex: i,
|
PartyIndex: i,
|
||||||
}
|
}
|
||||||
case err := <-errCh:
|
case err := <-errCh:
|
||||||
close(doneCh)
|
close(doneCh)
|
||||||
return nil, err
|
return nil, err
|
||||||
case <-time.After(5 * time.Minute):
|
case <-time.After(5 * time.Minute):
|
||||||
close(doneCh)
|
close(doneCh)
|
||||||
return nil, ErrKeygenTimeout
|
return nil, ErrKeygenTimeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
close(doneCh)
|
close(doneCh)
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateParty(party tss.Party, msg tss.Message, errCh chan error) {
|
func updateParty(party tss.Party, msg tss.Message, errCh chan error) {
|
||||||
bytes, routing, err := msg.WireBytes()
|
bytes, routing, err := msg.WireBytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errCh <- err
|
errCh <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
parsedMsg, err := tss.ParseWireMessage(bytes, msg.GetFrom(), routing.IsBroadcast)
|
parsedMsg, err := tss.ParseWireMessage(bytes, msg.GetFrom(), routing.IsBroadcast)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errCh <- err
|
errCh <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, err := party.Update(parsedMsg); err != nil {
|
if _, err := party.Update(parsedMsg); err != nil {
|
||||||
// Only send error if it's not a duplicate message error
|
// Only send error if it's not a duplicate message error
|
||||||
// Check if error message contains "duplicate message" indication
|
// Check if error message contains "duplicate message" indication
|
||||||
if err.Error() != "" && !isDuplicateMessageError(err) {
|
if err.Error() != "" && !isDuplicateMessageError(err) {
|
||||||
errCh <- err
|
errCh <- err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isDuplicateMessageError checks if an error is a duplicate message error
|
// isDuplicateMessageError checks if an error is a duplicate message error
|
||||||
func isDuplicateMessageError(err error) bool {
|
func isDuplicateMessageError(err error) bool {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
errStr := err.Error()
|
errStr := err.Error()
|
||||||
return strings.Contains(errStr, "duplicate") || strings.Contains(errStr, "already received")
|
return strings.Contains(errStr, "duplicate") || strings.Contains(errStr, "already received")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,435 +1,435 @@
|
||||||
package tss
|
package tss
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bnb-chain/tss-lib/v2/common"
|
"github.com/bnb-chain/tss-lib/v2/common"
|
||||||
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
|
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
|
||||||
"github.com/bnb-chain/tss-lib/v2/ecdsa/signing"
|
"github.com/bnb-chain/tss-lib/v2/ecdsa/signing"
|
||||||
"github.com/bnb-chain/tss-lib/v2/tss"
|
"github.com/bnb-chain/tss-lib/v2/tss"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrSigningTimeout = errors.New("signing timeout")
|
ErrSigningTimeout = errors.New("signing timeout")
|
||||||
ErrSigningFailed = errors.New("signing failed")
|
ErrSigningFailed = errors.New("signing failed")
|
||||||
ErrInvalidSignerCount = errors.New("invalid signer count")
|
ErrInvalidSignerCount = errors.New("invalid signer count")
|
||||||
ErrInvalidShareData = errors.New("invalid share data")
|
ErrInvalidShareData = errors.New("invalid share data")
|
||||||
)
|
)
|
||||||
|
|
||||||
// SigningResult contains the result of a signing operation
|
// SigningResult contains the result of a signing operation
|
||||||
type SigningResult struct {
|
type SigningResult struct {
|
||||||
// Signature is the full ECDSA signature (R || S)
|
// Signature is the full ECDSA signature (R || S)
|
||||||
Signature []byte
|
Signature []byte
|
||||||
// R is the R component of the signature
|
// R is the R component of the signature
|
||||||
R *big.Int
|
R *big.Int
|
||||||
// S is the S component of the signature
|
// S is the S component of the signature
|
||||||
S *big.Int
|
S *big.Int
|
||||||
// RecoveryID is the recovery ID for ecrecover
|
// RecoveryID is the recovery ID for ecrecover
|
||||||
RecoveryID int
|
RecoveryID int
|
||||||
}
|
}
|
||||||
|
|
||||||
// SigningParty represents a party participating in signing
|
// SigningParty represents a party participating in signing
|
||||||
type SigningParty struct {
|
type SigningParty struct {
|
||||||
PartyID string
|
PartyID string
|
||||||
PartyIndex int
|
PartyIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
// SigningConfig contains configuration for signing
|
// SigningConfig contains configuration for signing
|
||||||
type SigningConfig struct {
|
type SigningConfig struct {
|
||||||
Threshold int // t in t-of-n (number of signers required)
|
Threshold int // t in t-of-n (number of signers required)
|
||||||
TotalSigners int // Number of parties participating in this signing
|
TotalSigners int // Number of parties participating in this signing
|
||||||
Timeout time.Duration // Signing timeout
|
Timeout time.Duration // Signing timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
// SigningSession manages a signing session for a single party
|
// SigningSession manages a signing session for a single party
|
||||||
type SigningSession struct {
|
type SigningSession struct {
|
||||||
config SigningConfig
|
config SigningConfig
|
||||||
selfParty SigningParty
|
selfParty SigningParty
|
||||||
allParties []SigningParty
|
allParties []SigningParty
|
||||||
messageHash *big.Int
|
messageHash *big.Int
|
||||||
saveData *keygen.LocalPartySaveData
|
saveData *keygen.LocalPartySaveData
|
||||||
tssPartyIDs []*tss.PartyID
|
tssPartyIDs []*tss.PartyID
|
||||||
selfTSSID *tss.PartyID
|
selfTSSID *tss.PartyID
|
||||||
params *tss.Parameters
|
params *tss.Parameters
|
||||||
localParty tss.Party
|
localParty tss.Party
|
||||||
outCh chan tss.Message
|
outCh chan tss.Message
|
||||||
endCh chan *common.SignatureData
|
endCh chan *common.SignatureData
|
||||||
errCh chan error
|
errCh chan error
|
||||||
msgHandler MessageHandler
|
msgHandler MessageHandler
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
started bool
|
started bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSigningSession creates a new signing session
|
// NewSigningSession creates a new signing session
|
||||||
func NewSigningSession(
|
func NewSigningSession(
|
||||||
config SigningConfig,
|
config SigningConfig,
|
||||||
selfParty SigningParty,
|
selfParty SigningParty,
|
||||||
allParties []SigningParty,
|
allParties []SigningParty,
|
||||||
messageHash []byte,
|
messageHash []byte,
|
||||||
saveDataBytes []byte,
|
saveDataBytes []byte,
|
||||||
msgHandler MessageHandler,
|
msgHandler MessageHandler,
|
||||||
) (*SigningSession, error) {
|
) (*SigningSession, error) {
|
||||||
if config.TotalSigners < config.Threshold {
|
if config.TotalSigners < config.Threshold {
|
||||||
return nil, ErrInvalidSignerCount
|
return nil, ErrInvalidSignerCount
|
||||||
}
|
}
|
||||||
if len(allParties) != config.TotalSigners {
|
if len(allParties) != config.TotalSigners {
|
||||||
return nil, ErrInvalidSignerCount
|
return nil, ErrInvalidSignerCount
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deserialize save data
|
// Deserialize save data
|
||||||
var saveData keygen.LocalPartySaveData
|
var saveData keygen.LocalPartySaveData
|
||||||
if err := json.Unmarshal(saveDataBytes, &saveData); err != nil {
|
if err := json.Unmarshal(saveDataBytes, &saveData); err != nil {
|
||||||
return nil, fmt.Errorf("%w: %v", ErrInvalidShareData, err)
|
return nil, fmt.Errorf("%w: %v", ErrInvalidShareData, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create TSS party IDs for signers
|
// Create TSS party IDs for signers
|
||||||
tssPartyIDs := make([]*tss.PartyID, len(allParties))
|
tssPartyIDs := make([]*tss.PartyID, len(allParties))
|
||||||
var selfTSSID *tss.PartyID
|
var selfTSSID *tss.PartyID
|
||||||
for i, p := range allParties {
|
for i, p := range allParties {
|
||||||
partyID := tss.NewPartyID(
|
partyID := tss.NewPartyID(
|
||||||
p.PartyID,
|
p.PartyID,
|
||||||
fmt.Sprintf("party-%d", p.PartyIndex),
|
fmt.Sprintf("party-%d", p.PartyIndex),
|
||||||
big.NewInt(int64(p.PartyIndex+1)),
|
big.NewInt(int64(p.PartyIndex+1)),
|
||||||
)
|
)
|
||||||
tssPartyIDs[i] = partyID
|
tssPartyIDs[i] = partyID
|
||||||
if p.PartyID == selfParty.PartyID {
|
if p.PartyID == selfParty.PartyID {
|
||||||
selfTSSID = partyID
|
selfTSSID = partyID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if selfTSSID == nil {
|
if selfTSSID == nil {
|
||||||
return nil, errors.New("self party not found in all parties")
|
return nil, errors.New("self party not found in all parties")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort party IDs
|
// Sort party IDs
|
||||||
sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
|
sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
|
||||||
|
|
||||||
// Create peer context and parameters
|
// Create peer context and parameters
|
||||||
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
||||||
params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), config.Threshold)
|
params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), config.Threshold)
|
||||||
|
|
||||||
// Convert message hash to big.Int
|
// Convert message hash to big.Int
|
||||||
msgHash := new(big.Int).SetBytes(messageHash)
|
msgHash := new(big.Int).SetBytes(messageHash)
|
||||||
|
|
||||||
return &SigningSession{
|
return &SigningSession{
|
||||||
config: config,
|
config: config,
|
||||||
selfParty: selfParty,
|
selfParty: selfParty,
|
||||||
allParties: allParties,
|
allParties: allParties,
|
||||||
messageHash: msgHash,
|
messageHash: msgHash,
|
||||||
saveData: &saveData,
|
saveData: &saveData,
|
||||||
tssPartyIDs: sortedPartyIDs,
|
tssPartyIDs: sortedPartyIDs,
|
||||||
selfTSSID: selfTSSID,
|
selfTSSID: selfTSSID,
|
||||||
params: params,
|
params: params,
|
||||||
outCh: make(chan tss.Message, config.TotalSigners*10),
|
outCh: make(chan tss.Message, config.TotalSigners*10),
|
||||||
endCh: make(chan *common.SignatureData, 1),
|
endCh: make(chan *common.SignatureData, 1),
|
||||||
errCh: make(chan error, 1),
|
errCh: make(chan error, 1),
|
||||||
msgHandler: msgHandler,
|
msgHandler: msgHandler,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start begins the signing protocol
|
// Start begins the signing protocol
|
||||||
func (s *SigningSession) Start(ctx context.Context) (*SigningResult, error) {
|
func (s *SigningSession) Start(ctx context.Context) (*SigningResult, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if s.started {
|
if s.started {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return nil, errors.New("session already started")
|
return nil, errors.New("session already started")
|
||||||
}
|
}
|
||||||
s.started = true
|
s.started = true
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
// Create local party for signing
|
// Create local party for signing
|
||||||
s.localParty = signing.NewLocalParty(s.messageHash, s.params, *s.saveData, s.outCh, s.endCh)
|
s.localParty = signing.NewLocalParty(s.messageHash, s.params, *s.saveData, s.outCh, s.endCh)
|
||||||
|
|
||||||
// Start the local party
|
// Start the local party
|
||||||
go func() {
|
go func() {
|
||||||
if err := s.localParty.Start(); err != nil {
|
if err := s.localParty.Start(); err != nil {
|
||||||
s.errCh <- err
|
s.errCh <- err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Handle outgoing messages
|
// Handle outgoing messages
|
||||||
go s.handleOutgoingMessages(ctx)
|
go s.handleOutgoingMessages(ctx)
|
||||||
|
|
||||||
// Handle incoming messages
|
// Handle incoming messages
|
||||||
go s.handleIncomingMessages(ctx)
|
go s.handleIncomingMessages(ctx)
|
||||||
|
|
||||||
// Wait for completion or timeout
|
// Wait for completion or timeout
|
||||||
timeout := s.config.Timeout
|
timeout := s.config.Timeout
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = 5 * time.Minute
|
timeout = 5 * time.Minute
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
case <-time.After(timeout):
|
case <-time.After(timeout):
|
||||||
return nil, ErrSigningTimeout
|
return nil, ErrSigningTimeout
|
||||||
case tssErr := <-s.errCh:
|
case tssErr := <-s.errCh:
|
||||||
return nil, fmt.Errorf("%w: %v", ErrSigningFailed, tssErr)
|
return nil, fmt.Errorf("%w: %v", ErrSigningFailed, tssErr)
|
||||||
case signData := <-s.endCh:
|
case signData := <-s.endCh:
|
||||||
return s.buildResult(signData)
|
return s.buildResult(signData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SigningSession) handleOutgoingMessages(ctx context.Context) {
|
func (s *SigningSession) handleOutgoingMessages(ctx context.Context) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case msg := <-s.outCh:
|
case msg := <-s.outCh:
|
||||||
if msg == nil {
|
if msg == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
msgBytes, _, err := msg.WireBytes()
|
msgBytes, _, err := msg.WireBytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var toParties []string
|
var toParties []string
|
||||||
isBroadcast := msg.IsBroadcast()
|
isBroadcast := msg.IsBroadcast()
|
||||||
if !isBroadcast {
|
if !isBroadcast {
|
||||||
for _, to := range msg.GetTo() {
|
for _, to := range msg.GetTo() {
|
||||||
toParties = append(toParties, to.Id)
|
toParties = append(toParties, to.Id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.msgHandler.SendMessage(ctx, isBroadcast, toParties, msgBytes); err != nil {
|
if err := s.msgHandler.SendMessage(ctx, isBroadcast, toParties, msgBytes); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SigningSession) handleIncomingMessages(ctx context.Context) {
|
func (s *SigningSession) handleIncomingMessages(ctx context.Context) {
|
||||||
msgCh := s.msgHandler.ReceiveMessages()
|
msgCh := s.msgHandler.ReceiveMessages()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case msg, ok := <-msgCh:
|
case msg, ok := <-msgCh:
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the message
|
// Parse the message
|
||||||
parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, s.tssPartyIDs[msg.FromPartyIndex], msg.IsBroadcast)
|
parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, s.tssPartyIDs[msg.FromPartyIndex], msg.IsBroadcast)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the party
|
// Update the party
|
||||||
go func() {
|
go func() {
|
||||||
ok, err := s.localParty.Update(parsedMsg)
|
ok, err := s.localParty.Update(parsedMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.errCh <- err
|
s.errCh <- err
|
||||||
}
|
}
|
||||||
_ = ok
|
_ = ok
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SigningSession) buildResult(signData *common.SignatureData) (*SigningResult, error) {
|
func (s *SigningSession) buildResult(signData *common.SignatureData) (*SigningResult, error) {
|
||||||
// Get R and S as big.Int
|
// Get R and S as big.Int
|
||||||
r := new(big.Int).SetBytes(signData.R)
|
r := new(big.Int).SetBytes(signData.R)
|
||||||
rS := new(big.Int).SetBytes(signData.S)
|
rS := new(big.Int).SetBytes(signData.S)
|
||||||
|
|
||||||
// Build full signature (R || S)
|
// Build full signature (R || S)
|
||||||
signature := make([]byte, 64)
|
signature := make([]byte, 64)
|
||||||
rBytes := signData.R
|
rBytes := signData.R
|
||||||
sBytes := signData.S
|
sBytes := signData.S
|
||||||
|
|
||||||
// Pad to 32 bytes each
|
// Pad to 32 bytes each
|
||||||
copy(signature[32-len(rBytes):32], rBytes)
|
copy(signature[32-len(rBytes):32], rBytes)
|
||||||
copy(signature[64-len(sBytes):64], sBytes)
|
copy(signature[64-len(sBytes):64], sBytes)
|
||||||
|
|
||||||
// Calculate recovery ID
|
// Calculate recovery ID
|
||||||
recoveryID := int(signData.SignatureRecovery[0])
|
recoveryID := int(signData.SignatureRecovery[0])
|
||||||
|
|
||||||
return &SigningResult{
|
return &SigningResult{
|
||||||
Signature: signature,
|
Signature: signature,
|
||||||
R: r,
|
R: r,
|
||||||
S: rS,
|
S: rS,
|
||||||
RecoveryID: recoveryID,
|
RecoveryID: recoveryID,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalSigningResult contains local signing result for standalone testing
|
// LocalSigningResult contains local signing result for standalone testing
|
||||||
type LocalSigningResult struct {
|
type LocalSigningResult struct {
|
||||||
Signature []byte
|
Signature []byte
|
||||||
R *big.Int
|
R *big.Int
|
||||||
S *big.Int
|
S *big.Int
|
||||||
RecoveryID int
|
RecoveryID int
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunLocalSigning runs signing locally with all parties in the same process (for testing)
|
// RunLocalSigning runs signing locally with all parties in the same process (for testing)
|
||||||
func RunLocalSigning(
|
func RunLocalSigning(
|
||||||
threshold int,
|
threshold int,
|
||||||
keygenResults []*LocalKeygenResult,
|
keygenResults []*LocalKeygenResult,
|
||||||
messageHash []byte,
|
messageHash []byte,
|
||||||
) (*LocalSigningResult, error) {
|
) (*LocalSigningResult, error) {
|
||||||
signerCount := len(keygenResults)
|
signerCount := len(keygenResults)
|
||||||
if signerCount < threshold {
|
if signerCount < threshold {
|
||||||
return nil, ErrInvalidSignerCount
|
return nil, ErrInvalidSignerCount
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create party IDs for signers using their ORIGINAL party indices from keygen
|
// Create party IDs for signers using their ORIGINAL party indices from keygen
|
||||||
// This is critical for subset signing - party IDs must match the original keygen party IDs
|
// This is critical for subset signing - party IDs must match the original keygen party IDs
|
||||||
partyIDs := make([]*tss.PartyID, signerCount)
|
partyIDs := make([]*tss.PartyID, signerCount)
|
||||||
for i, result := range keygenResults {
|
for i, result := range keygenResults {
|
||||||
idx := result.PartyIndex
|
idx := result.PartyIndex
|
||||||
partyIDs[i] = tss.NewPartyID(
|
partyIDs[i] = tss.NewPartyID(
|
||||||
fmt.Sprintf("party-%d", idx),
|
fmt.Sprintf("party-%d", idx),
|
||||||
fmt.Sprintf("party-%d", idx),
|
fmt.Sprintf("party-%d", idx),
|
||||||
big.NewInt(int64(idx+1)),
|
big.NewInt(int64(idx+1)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
sortedPartyIDs := tss.SortPartyIDs(partyIDs)
|
sortedPartyIDs := tss.SortPartyIDs(partyIDs)
|
||||||
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
||||||
|
|
||||||
// Convert message hash to big.Int
|
// Convert message hash to big.Int
|
||||||
msgHash := new(big.Int).SetBytes(messageHash)
|
msgHash := new(big.Int).SetBytes(messageHash)
|
||||||
|
|
||||||
// Create channels for each party
|
// Create channels for each party
|
||||||
outChs := make([]chan tss.Message, signerCount)
|
outChs := make([]chan tss.Message, signerCount)
|
||||||
endChs := make([]chan *common.SignatureData, signerCount)
|
endChs := make([]chan *common.SignatureData, signerCount)
|
||||||
parties := make([]tss.Party, signerCount)
|
parties := make([]tss.Party, signerCount)
|
||||||
|
|
||||||
// Map sorted party IDs back to keygen results
|
// Map sorted party IDs back to keygen results
|
||||||
sortedKeygenResults := make([]*LocalKeygenResult, signerCount)
|
sortedKeygenResults := make([]*LocalKeygenResult, signerCount)
|
||||||
for i, pid := range sortedPartyIDs {
|
for i, pid := range sortedPartyIDs {
|
||||||
for _, result := range keygenResults {
|
for _, result := range keygenResults {
|
||||||
expectedID := fmt.Sprintf("party-%d", result.PartyIndex)
|
expectedID := fmt.Sprintf("party-%d", result.PartyIndex)
|
||||||
if pid.Id == expectedID {
|
if pid.Id == expectedID {
|
||||||
sortedKeygenResults[i] = result
|
sortedKeygenResults[i] = result
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < signerCount; i++ {
|
for i := 0; i < signerCount; i++ {
|
||||||
outChs[i] = make(chan tss.Message, signerCount*10)
|
outChs[i] = make(chan tss.Message, signerCount*10)
|
||||||
endChs[i] = make(chan *common.SignatureData, 1)
|
endChs[i] = make(chan *common.SignatureData, 1)
|
||||||
params := tss.NewParameters(tss.S256(), peerCtx, sortedPartyIDs[i], signerCount, threshold)
|
params := tss.NewParameters(tss.S256(), peerCtx, sortedPartyIDs[i], signerCount, threshold)
|
||||||
parties[i] = signing.NewLocalParty(msgHash, params, *sortedKeygenResults[i].SaveData, outChs[i], endChs[i])
|
parties[i] = signing.NewLocalParty(msgHash, params, *sortedKeygenResults[i].SaveData, outChs[i], endChs[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start all parties
|
// Start all parties
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
errCh := make(chan error, signerCount)
|
errCh := make(chan error, signerCount)
|
||||||
|
|
||||||
for i := 0; i < signerCount; i++ {
|
for i := 0; i < signerCount; i++ {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(idx int) {
|
go func(idx int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
if err := parties[idx].Start(); err != nil {
|
if err := parties[idx].Start(); err != nil {
|
||||||
errCh <- err
|
errCh <- err
|
||||||
}
|
}
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route messages between parties
|
// Route messages between parties
|
||||||
var routeWg sync.WaitGroup
|
var routeWg sync.WaitGroup
|
||||||
doneCh := make(chan struct{})
|
doneCh := make(chan struct{})
|
||||||
|
|
||||||
for i := 0; i < signerCount; i++ {
|
for i := 0; i < signerCount; i++ {
|
||||||
routeWg.Add(1)
|
routeWg.Add(1)
|
||||||
go func(idx int) {
|
go func(idx int) {
|
||||||
defer routeWg.Done()
|
defer routeWg.Done()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-doneCh:
|
case <-doneCh:
|
||||||
return
|
return
|
||||||
case msg := <-outChs[idx]:
|
case msg := <-outChs[idx]:
|
||||||
if msg == nil {
|
if msg == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dest := msg.GetTo()
|
dest := msg.GetTo()
|
||||||
if msg.IsBroadcast() {
|
if msg.IsBroadcast() {
|
||||||
for j := 0; j < signerCount; j++ {
|
for j := 0; j < signerCount; j++ {
|
||||||
if j != idx {
|
if j != idx {
|
||||||
go updateSignParty(parties[j], msg, errCh)
|
go updateSignParty(parties[j], msg, errCh)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for _, d := range dest {
|
for _, d := range dest {
|
||||||
for j := 0; j < signerCount; j++ {
|
for j := 0; j < signerCount; j++ {
|
||||||
if sortedPartyIDs[j].Id == d.Id {
|
if sortedPartyIDs[j].Id == d.Id {
|
||||||
go updateSignParty(parties[j], msg, errCh)
|
go updateSignParty(parties[j], msg, errCh)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collect first result (all parties should produce same signature)
|
// Collect first result (all parties should produce same signature)
|
||||||
var result *LocalSigningResult
|
var result *LocalSigningResult
|
||||||
for i := 0; i < signerCount; i++ {
|
for i := 0; i < signerCount; i++ {
|
||||||
select {
|
select {
|
||||||
case signData := <-endChs[i]:
|
case signData := <-endChs[i]:
|
||||||
if result == nil {
|
if result == nil {
|
||||||
r := new(big.Int).SetBytes(signData.R)
|
r := new(big.Int).SetBytes(signData.R)
|
||||||
rS := new(big.Int).SetBytes(signData.S)
|
rS := new(big.Int).SetBytes(signData.S)
|
||||||
|
|
||||||
signature := make([]byte, 64)
|
signature := make([]byte, 64)
|
||||||
copy(signature[32-len(signData.R):32], signData.R)
|
copy(signature[32-len(signData.R):32], signData.R)
|
||||||
copy(signature[64-len(signData.S):64], signData.S)
|
copy(signature[64-len(signData.S):64], signData.S)
|
||||||
|
|
||||||
result = &LocalSigningResult{
|
result = &LocalSigningResult{
|
||||||
Signature: signature,
|
Signature: signature,
|
||||||
R: r,
|
R: r,
|
||||||
S: rS,
|
S: rS,
|
||||||
RecoveryID: int(signData.SignatureRecovery[0]),
|
RecoveryID: int(signData.SignatureRecovery[0]),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case err := <-errCh:
|
case err := <-errCh:
|
||||||
close(doneCh)
|
close(doneCh)
|
||||||
return nil, err
|
return nil, err
|
||||||
case <-time.After(5 * time.Minute):
|
case <-time.After(5 * time.Minute):
|
||||||
close(doneCh)
|
close(doneCh)
|
||||||
return nil, ErrSigningTimeout
|
return nil, ErrSigningTimeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
close(doneCh)
|
close(doneCh)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateSignParty(party tss.Party, msg tss.Message, errCh chan error) {
|
func updateSignParty(party tss.Party, msg tss.Message, errCh chan error) {
|
||||||
bytes, routing, err := msg.WireBytes()
|
bytes, routing, err := msg.WireBytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errCh <- err
|
errCh <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
parsedMsg, err := tss.ParseWireMessage(bytes, msg.GetFrom(), routing.IsBroadcast)
|
parsedMsg, err := tss.ParseWireMessage(bytes, msg.GetFrom(), routing.IsBroadcast)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errCh <- err
|
errCh <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, err := party.Update(parsedMsg); err != nil {
|
if _, err := party.Update(parsedMsg); err != nil {
|
||||||
// Only send error if it's not a duplicate message error
|
// Only send error if it's not a duplicate message error
|
||||||
if err.Error() != "" && !isSignDuplicateMessageError(err) {
|
if err.Error() != "" && !isSignDuplicateMessageError(err) {
|
||||||
errCh <- err
|
errCh <- err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isSignDuplicateMessageError checks if an error is a duplicate message error
|
// isSignDuplicateMessageError checks if an error is a duplicate message error
|
||||||
func isSignDuplicateMessageError(err error) bool {
|
func isSignDuplicateMessageError(err error) bool {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
errStr := err.Error()
|
errStr := err.Error()
|
||||||
return strings.Contains(errStr, "duplicate") || strings.Contains(errStr, "already received")
|
return strings.Contains(errStr, "duplicate") || strings.Contains(errStr, "already received")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,476 +1,476 @@
|
||||||
package tss
|
package tss
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
stdecdsa "crypto/ecdsa"
|
stdecdsa "crypto/ecdsa"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"math/big"
|
"math/big"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec/v2"
|
"github.com/btcsuite/btcd/btcec/v2"
|
||||||
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
|
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestRunLocalKeygen tests the local keygen functionality
|
// TestRunLocalKeygen tests the local keygen functionality
|
||||||
func TestRunLocalKeygen(t *testing.T) {
|
func TestRunLocalKeygen(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
threshold int
|
threshold int
|
||||||
totalParties int
|
totalParties int
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "2-of-3 keygen",
|
name: "2-of-3 keygen",
|
||||||
threshold: 2,
|
threshold: 2,
|
||||||
totalParties: 3,
|
totalParties: 3,
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "2-of-2 keygen",
|
name: "2-of-2 keygen",
|
||||||
threshold: 2,
|
threshold: 2,
|
||||||
totalParties: 2,
|
totalParties: 2,
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid party count",
|
name: "invalid party count",
|
||||||
threshold: 2,
|
threshold: 2,
|
||||||
totalParties: 1,
|
totalParties: 1,
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid threshold",
|
name: "invalid threshold",
|
||||||
threshold: 0,
|
threshold: 0,
|
||||||
totalParties: 3,
|
totalParties: 3,
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "threshold greater than parties",
|
name: "threshold greater than parties",
|
||||||
threshold: 4,
|
threshold: 4,
|
||||||
totalParties: 3,
|
totalParties: 3,
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
results, err := RunLocalKeygen(tt.threshold, tt.totalParties)
|
results, err := RunLocalKeygen(tt.threshold, tt.totalParties)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("RunLocalKeygen() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("RunLocalKeygen() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify results
|
// Verify results
|
||||||
if len(results) != tt.totalParties {
|
if len(results) != tt.totalParties {
|
||||||
t.Errorf("Expected %d results, got %d", tt.totalParties, len(results))
|
t.Errorf("Expected %d results, got %d", tt.totalParties, len(results))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify all parties have the same public key
|
// Verify all parties have the same public key
|
||||||
var firstPubKey *stdecdsa.PublicKey
|
var firstPubKey *stdecdsa.PublicKey
|
||||||
for i, result := range results {
|
for i, result := range results {
|
||||||
if result.SaveData == nil {
|
if result.SaveData == nil {
|
||||||
t.Errorf("Party %d has nil SaveData", i)
|
t.Errorf("Party %d has nil SaveData", i)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if result.PublicKey == nil {
|
if result.PublicKey == nil {
|
||||||
t.Errorf("Party %d has nil PublicKey", i)
|
t.Errorf("Party %d has nil PublicKey", i)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if firstPubKey == nil {
|
if firstPubKey == nil {
|
||||||
firstPubKey = result.PublicKey
|
firstPubKey = result.PublicKey
|
||||||
} else {
|
} else {
|
||||||
// Compare public keys
|
// Compare public keys
|
||||||
if result.PublicKey.X.Cmp(firstPubKey.X) != 0 ||
|
if result.PublicKey.X.Cmp(firstPubKey.X) != 0 ||
|
||||||
result.PublicKey.Y.Cmp(firstPubKey.Y) != 0 {
|
result.PublicKey.Y.Cmp(firstPubKey.Y) != 0 {
|
||||||
t.Errorf("Party %d has different public key", i)
|
t.Errorf("Party %d has different public key", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("Keygen successful: %d-of-%d, public key X: %s",
|
t.Logf("Keygen successful: %d-of-%d, public key X: %s",
|
||||||
tt.threshold, tt.totalParties, firstPubKey.X.Text(16)[:16]+"...")
|
tt.threshold, tt.totalParties, firstPubKey.X.Text(16)[:16]+"...")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRunLocalSigning tests the local signing functionality
|
// TestRunLocalSigning tests the local signing functionality
|
||||||
func TestRunLocalSigning(t *testing.T) {
|
func TestRunLocalSigning(t *testing.T) {
|
||||||
// First run keygen to get key shares
|
// First run keygen to get key shares
|
||||||
threshold := 2
|
threshold := 2
|
||||||
totalParties := 3
|
totalParties := 3
|
||||||
|
|
||||||
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Keygen failed: %v", err)
|
t.Fatalf("Keygen failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create message hash
|
// Create message hash
|
||||||
message := []byte("Hello, MPC signing!")
|
message := []byte("Hello, MPC signing!")
|
||||||
messageHash := sha256.Sum256(message)
|
messageHash := sha256.Sum256(message)
|
||||||
|
|
||||||
// Run signing
|
// Run signing
|
||||||
signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:])
|
signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Signing failed: %v", err)
|
t.Fatalf("Signing failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify signature
|
// Verify signature
|
||||||
if signResult == nil {
|
if signResult == nil {
|
||||||
t.Fatal("Sign result is nil")
|
t.Fatal("Sign result is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(signResult.Signature) != 64 {
|
if len(signResult.Signature) != 64 {
|
||||||
t.Errorf("Expected 64-byte signature, got %d bytes", len(signResult.Signature))
|
t.Errorf("Expected 64-byte signature, got %d bytes", len(signResult.Signature))
|
||||||
}
|
}
|
||||||
|
|
||||||
if signResult.R == nil || signResult.S == nil {
|
if signResult.R == nil || signResult.S == nil {
|
||||||
t.Error("R or S is nil")
|
t.Error("R or S is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify signature using the public key
|
// Verify signature using the public key
|
||||||
pubKey := keygenResults[0].PublicKey
|
pubKey := keygenResults[0].PublicKey
|
||||||
valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S)
|
valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S)
|
||||||
if !valid {
|
if !valid {
|
||||||
t.Error("Signature verification failed")
|
t.Error("Signature verification failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("Signing successful: R=%s..., S=%s...",
|
t.Logf("Signing successful: R=%s..., S=%s...",
|
||||||
signResult.R.Text(16)[:16], signResult.S.Text(16)[:16])
|
signResult.R.Text(16)[:16], signResult.S.Text(16)[:16])
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestMultipleSigning tests signing multiple messages with the same keys
|
// TestMultipleSigning tests signing multiple messages with the same keys
|
||||||
func TestMultipleSigning(t *testing.T) {
|
func TestMultipleSigning(t *testing.T) {
|
||||||
// Run keygen
|
// Run keygen
|
||||||
threshold := 2
|
threshold := 2
|
||||||
totalParties := 3
|
totalParties := 3
|
||||||
|
|
||||||
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Keygen failed: %v", err)
|
t.Fatalf("Keygen failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
messages := []string{
|
messages := []string{
|
||||||
"First message",
|
"First message",
|
||||||
"Second message",
|
"Second message",
|
||||||
"Third message",
|
"Third message",
|
||||||
}
|
}
|
||||||
|
|
||||||
pubKey := keygenResults[0].PublicKey
|
pubKey := keygenResults[0].PublicKey
|
||||||
|
|
||||||
for i, msg := range messages {
|
for i, msg := range messages {
|
||||||
messageHash := sha256.Sum256([]byte(msg))
|
messageHash := sha256.Sum256([]byte(msg))
|
||||||
signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:])
|
signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Signing message %d failed: %v", i, err)
|
t.Errorf("Signing message %d failed: %v", i, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S)
|
valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S)
|
||||||
if !valid {
|
if !valid {
|
||||||
t.Errorf("Signature %d verification failed", i)
|
t.Errorf("Signature %d verification failed", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestSigningWithSubsetOfParties tests signing with a subset of parties
|
// TestSigningWithSubsetOfParties tests signing with a subset of parties
|
||||||
// In tss-lib, threshold `t` means `t+1` parties are needed to sign.
|
// In tss-lib, threshold `t` means `t+1` parties are needed to sign.
|
||||||
// For a 2-of-3 scheme (2 signers needed), we use threshold=1 (1+1=2).
|
// For a 2-of-3 scheme (2 signers needed), we use threshold=1 (1+1=2).
|
||||||
func TestSigningWithSubsetOfParties(t *testing.T) {
|
func TestSigningWithSubsetOfParties(t *testing.T) {
|
||||||
// For a 2-of-3 scheme in tss-lib:
|
// For a 2-of-3 scheme in tss-lib:
|
||||||
// - totalParties (n) = 3
|
// - totalParties (n) = 3
|
||||||
// - threshold (t) = 1 (meaning t+1=2 parties are required to sign)
|
// - threshold (t) = 1 (meaning t+1=2 parties are required to sign)
|
||||||
threshold := 1
|
threshold := 1
|
||||||
totalParties := 3
|
totalParties := 3
|
||||||
|
|
||||||
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Keygen failed: %v", err)
|
t.Fatalf("Keygen failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sign with only 2 parties (party 0 and party 1) - this should work with t=1
|
// Sign with only 2 parties (party 0 and party 1) - this should work with t=1
|
||||||
signers := []*LocalKeygenResult{
|
signers := []*LocalKeygenResult{
|
||||||
keygenResults[0],
|
keygenResults[0],
|
||||||
keygenResults[1],
|
keygenResults[1],
|
||||||
}
|
}
|
||||||
|
|
||||||
message := []byte("Threshold signing test")
|
message := []byte("Threshold signing test")
|
||||||
messageHash := sha256.Sum256(message)
|
messageHash := sha256.Sum256(message)
|
||||||
|
|
||||||
signResult, err := RunLocalSigning(threshold, signers, messageHash[:])
|
signResult, err := RunLocalSigning(threshold, signers, messageHash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Signing with subset failed: %v", err)
|
t.Fatalf("Signing with subset failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify signature
|
// Verify signature
|
||||||
pubKey := keygenResults[0].PublicKey
|
pubKey := keygenResults[0].PublicKey
|
||||||
valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S)
|
valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S)
|
||||||
if !valid {
|
if !valid {
|
||||||
t.Error("Signature verification failed for subset signing")
|
t.Error("Signature verification failed for subset signing")
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("Subset signing (2-of-3) successful with threshold=1")
|
t.Log("Subset signing (2-of-3) successful with threshold=1")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestSigningWithDifferentSubsets tests signing with different party combinations
|
// TestSigningWithDifferentSubsets tests signing with different party combinations
|
||||||
// In tss-lib, threshold `t` means `t+1` parties are needed to sign.
|
// In tss-lib, threshold `t` means `t+1` parties are needed to sign.
|
||||||
// For a 2-of-3 scheme (2 signers needed), we use threshold=1.
|
// For a 2-of-3 scheme (2 signers needed), we use threshold=1.
|
||||||
func TestSigningWithDifferentSubsets(t *testing.T) {
|
func TestSigningWithDifferentSubsets(t *testing.T) {
|
||||||
// For 2-of-3 in tss-lib terminology: threshold=1 means t+1=2 signers needed
|
// For 2-of-3 in tss-lib terminology: threshold=1 means t+1=2 signers needed
|
||||||
threshold := 1
|
threshold := 1
|
||||||
totalParties := 3
|
totalParties := 3
|
||||||
|
|
||||||
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Keygen failed: %v", err)
|
t.Fatalf("Keygen failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pubKey := keygenResults[0].PublicKey
|
pubKey := keygenResults[0].PublicKey
|
||||||
|
|
||||||
// Test different combinations of 2 parties (the minimum required with t=1)
|
// Test different combinations of 2 parties (the minimum required with t=1)
|
||||||
combinations := [][]*LocalKeygenResult{
|
combinations := [][]*LocalKeygenResult{
|
||||||
{keygenResults[0], keygenResults[1]}, // parties 0,1
|
{keygenResults[0], keygenResults[1]}, // parties 0,1
|
||||||
{keygenResults[0], keygenResults[2]}, // parties 0,2
|
{keygenResults[0], keygenResults[2]}, // parties 0,2
|
||||||
{keygenResults[1], keygenResults[2]}, // parties 1,2
|
{keygenResults[1], keygenResults[2]}, // parties 1,2
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, signers := range combinations {
|
for i, signers := range combinations {
|
||||||
message := []byte("Test message " + string(rune('A'+i)))
|
message := []byte("Test message " + string(rune('A'+i)))
|
||||||
messageHash := sha256.Sum256(message)
|
messageHash := sha256.Sum256(message)
|
||||||
|
|
||||||
signResult, err := RunLocalSigning(threshold, signers, messageHash[:])
|
signResult, err := RunLocalSigning(threshold, signers, messageHash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Signing with combination %d failed: %v", i, err)
|
t.Errorf("Signing with combination %d failed: %v", i, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S)
|
valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S)
|
||||||
if !valid {
|
if !valid {
|
||||||
t.Errorf("Signature verification failed for combination %d", i)
|
t.Errorf("Signature verification failed for combination %d", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("All subset combinations successful")
|
t.Log("All subset combinations successful")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestKeygenResultConsistency tests that all parties produce consistent results
|
// TestKeygenResultConsistency tests that all parties produce consistent results
|
||||||
func TestKeygenResultConsistency(t *testing.T) {
|
func TestKeygenResultConsistency(t *testing.T) {
|
||||||
threshold := 2
|
threshold := 2
|
||||||
totalParties := 3
|
totalParties := 3
|
||||||
|
|
||||||
results, err := RunLocalKeygen(threshold, totalParties)
|
results, err := RunLocalKeygen(threshold, totalParties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Keygen failed: %v", err)
|
t.Fatalf("Keygen failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// All parties should have the same ECDSAPub
|
// All parties should have the same ECDSAPub
|
||||||
var refX, refY *big.Int
|
var refX, refY *big.Int
|
||||||
for i, result := range results {
|
for i, result := range results {
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
refX = result.SaveData.ECDSAPub.X()
|
refX = result.SaveData.ECDSAPub.X()
|
||||||
refY = result.SaveData.ECDSAPub.Y()
|
refY = result.SaveData.ECDSAPub.Y()
|
||||||
} else {
|
} else {
|
||||||
if result.SaveData.ECDSAPub.X().Cmp(refX) != 0 {
|
if result.SaveData.ECDSAPub.X().Cmp(refX) != 0 {
|
||||||
t.Errorf("Party %d X coordinate mismatch", i)
|
t.Errorf("Party %d X coordinate mismatch", i)
|
||||||
}
|
}
|
||||||
if result.SaveData.ECDSAPub.Y().Cmp(refY) != 0 {
|
if result.SaveData.ECDSAPub.Y().Cmp(refY) != 0 {
|
||||||
t.Errorf("Party %d Y coordinate mismatch", i)
|
t.Errorf("Party %d Y coordinate mismatch", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestSignatureRecovery tests that the recovery ID allows public key recovery
|
// TestSignatureRecovery tests that the recovery ID allows public key recovery
|
||||||
func TestSignatureRecovery(t *testing.T) {
|
func TestSignatureRecovery(t *testing.T) {
|
||||||
threshold := 2
|
threshold := 2
|
||||||
totalParties := 3
|
totalParties := 3
|
||||||
|
|
||||||
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Keygen failed: %v", err)
|
t.Fatalf("Keygen failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
message := []byte("Recovery test message")
|
message := []byte("Recovery test message")
|
||||||
messageHash := sha256.Sum256(message)
|
messageHash := sha256.Sum256(message)
|
||||||
|
|
||||||
signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:])
|
signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Signing failed: %v", err)
|
t.Fatalf("Signing failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the recovery ID is valid (0-3)
|
// Verify the recovery ID is valid (0-3)
|
||||||
if signResult.RecoveryID < 0 || signResult.RecoveryID > 3 {
|
if signResult.RecoveryID < 0 || signResult.RecoveryID > 3 {
|
||||||
t.Errorf("Invalid recovery ID: %d", signResult.RecoveryID)
|
t.Errorf("Invalid recovery ID: %d", signResult.RecoveryID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify we can create a btcec signature and verify it
|
// Verify we can create a btcec signature and verify it
|
||||||
r := new(btcec.ModNScalar)
|
r := new(btcec.ModNScalar)
|
||||||
r.SetByteSlice(signResult.R.Bytes())
|
r.SetByteSlice(signResult.R.Bytes())
|
||||||
s := new(btcec.ModNScalar)
|
s := new(btcec.ModNScalar)
|
||||||
s.SetByteSlice(signResult.S.Bytes())
|
s.SetByteSlice(signResult.S.Bytes())
|
||||||
|
|
||||||
btcSig := ecdsa.NewSignature(r, s)
|
btcSig := ecdsa.NewSignature(r, s)
|
||||||
|
|
||||||
// Convert public key to btcec format
|
// Convert public key to btcec format
|
||||||
originalPub := keygenResults[0].PublicKey
|
originalPub := keygenResults[0].PublicKey
|
||||||
btcPubKey, err := btcec.ParsePubKey(append([]byte{0x04}, append(originalPub.X.Bytes(), originalPub.Y.Bytes()...)...))
|
btcPubKey, err := btcec.ParsePubKey(append([]byte{0x04}, append(originalPub.X.Bytes(), originalPub.Y.Bytes()...)...))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("Failed to parse public key: %v", err)
|
t.Logf("Failed to parse public key: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the signature
|
// Verify the signature
|
||||||
verified := btcSig.Verify(messageHash[:], btcPubKey)
|
verified := btcSig.Verify(messageHash[:], btcPubKey)
|
||||||
if !verified {
|
if !verified {
|
||||||
t.Error("btcec signature verification failed")
|
t.Error("btcec signature verification failed")
|
||||||
} else {
|
} else {
|
||||||
t.Log("btcec signature verification successful")
|
t.Log("btcec signature verification successful")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestNewKeygenSession tests creating a new keygen session
|
// TestNewKeygenSession tests creating a new keygen session
|
||||||
func TestNewKeygenSession(t *testing.T) {
|
func TestNewKeygenSession(t *testing.T) {
|
||||||
config := KeygenConfig{
|
config := KeygenConfig{
|
||||||
Threshold: 2,
|
Threshold: 2,
|
||||||
TotalParties: 3,
|
TotalParties: 3,
|
||||||
}
|
}
|
||||||
|
|
||||||
selfParty := KeygenParty{PartyID: "party-0", PartyIndex: 0}
|
selfParty := KeygenParty{PartyID: "party-0", PartyIndex: 0}
|
||||||
allParties := []KeygenParty{
|
allParties := []KeygenParty{
|
||||||
{PartyID: "party-0", PartyIndex: 0},
|
{PartyID: "party-0", PartyIndex: 0},
|
||||||
{PartyID: "party-1", PartyIndex: 1},
|
{PartyID: "party-1", PartyIndex: 1},
|
||||||
{PartyID: "party-2", PartyIndex: 2},
|
{PartyID: "party-2", PartyIndex: 2},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a mock message handler
|
// Create a mock message handler
|
||||||
handler := &mockMessageHandler{
|
handler := &mockMessageHandler{
|
||||||
msgCh: make(chan *ReceivedMessage, 100),
|
msgCh: make(chan *ReceivedMessage, 100),
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := NewKeygenSession(config, selfParty, allParties, handler)
|
session, err := NewKeygenSession(config, selfParty, allParties, handler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create keygen session: %v", err)
|
t.Fatalf("Failed to create keygen session: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if session == nil {
|
if session == nil {
|
||||||
t.Fatal("Session is nil")
|
t.Fatal("Session is nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestNewKeygenSessionValidation tests validation in NewKeygenSession
|
// TestNewKeygenSessionValidation tests validation in NewKeygenSession
|
||||||
func TestNewKeygenSessionValidation(t *testing.T) {
|
func TestNewKeygenSessionValidation(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
config KeygenConfig
|
config KeygenConfig
|
||||||
selfParty KeygenParty
|
selfParty KeygenParty
|
||||||
allParties []KeygenParty
|
allParties []KeygenParty
|
||||||
wantErr bool
|
wantErr bool
|
||||||
expectedErr error
|
expectedErr error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "invalid party count",
|
name: "invalid party count",
|
||||||
config: KeygenConfig{
|
config: KeygenConfig{
|
||||||
Threshold: 2,
|
Threshold: 2,
|
||||||
TotalParties: 1,
|
TotalParties: 1,
|
||||||
},
|
},
|
||||||
selfParty: KeygenParty{PartyID: "party-0", PartyIndex: 0},
|
selfParty: KeygenParty{PartyID: "party-0", PartyIndex: 0},
|
||||||
allParties: []KeygenParty{{PartyID: "party-0", PartyIndex: 0}},
|
allParties: []KeygenParty{{PartyID: "party-0", PartyIndex: 0}},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
expectedErr: ErrInvalidPartyCount,
|
expectedErr: ErrInvalidPartyCount,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid threshold - zero",
|
name: "invalid threshold - zero",
|
||||||
config: KeygenConfig{
|
config: KeygenConfig{
|
||||||
Threshold: 0,
|
Threshold: 0,
|
||||||
TotalParties: 3,
|
TotalParties: 3,
|
||||||
},
|
},
|
||||||
selfParty: KeygenParty{PartyID: "party-0", PartyIndex: 0},
|
selfParty: KeygenParty{PartyID: "party-0", PartyIndex: 0},
|
||||||
allParties: []KeygenParty{{PartyID: "party-0", PartyIndex: 0}, {PartyID: "party-1", PartyIndex: 1}, {PartyID: "party-2", PartyIndex: 2}},
|
allParties: []KeygenParty{{PartyID: "party-0", PartyIndex: 0}, {PartyID: "party-1", PartyIndex: 1}, {PartyID: "party-2", PartyIndex: 2}},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
expectedErr: ErrInvalidThreshold,
|
expectedErr: ErrInvalidThreshold,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mismatched party count",
|
name: "mismatched party count",
|
||||||
config: KeygenConfig{
|
config: KeygenConfig{
|
||||||
Threshold: 2,
|
Threshold: 2,
|
||||||
TotalParties: 3,
|
TotalParties: 3,
|
||||||
},
|
},
|
||||||
selfParty: KeygenParty{PartyID: "party-0", PartyIndex: 0},
|
selfParty: KeygenParty{PartyID: "party-0", PartyIndex: 0},
|
||||||
allParties: []KeygenParty{{PartyID: "party-0", PartyIndex: 0}, {PartyID: "party-1", PartyIndex: 1}},
|
allParties: []KeygenParty{{PartyID: "party-0", PartyIndex: 0}, {PartyID: "party-1", PartyIndex: 1}},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
expectedErr: ErrInvalidPartyCount,
|
expectedErr: ErrInvalidPartyCount,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
handler := &mockMessageHandler{msgCh: make(chan *ReceivedMessage)}
|
handler := &mockMessageHandler{msgCh: make(chan *ReceivedMessage)}
|
||||||
_, err := NewKeygenSession(tt.config, tt.selfParty, tt.allParties, handler)
|
_, err := NewKeygenSession(tt.config, tt.selfParty, tt.allParties, handler)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("NewKeygenSession() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("NewKeygenSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
if tt.expectedErr != nil && err != tt.expectedErr {
|
if tt.expectedErr != nil && err != tt.expectedErr {
|
||||||
t.Errorf("Expected error %v, got %v", tt.expectedErr, err)
|
t.Errorf("Expected error %v, got %v", tt.expectedErr, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mockMessageHandler is a mock implementation of MessageHandler for testing
|
// mockMessageHandler is a mock implementation of MessageHandler for testing
|
||||||
type mockMessageHandler struct {
|
type mockMessageHandler struct {
|
||||||
msgCh chan *ReceivedMessage
|
msgCh chan *ReceivedMessage
|
||||||
sentMsgs []sentMessage
|
sentMsgs []sentMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
type sentMessage struct {
|
type sentMessage struct {
|
||||||
isBroadcast bool
|
isBroadcast bool
|
||||||
toParties []string
|
toParties []string
|
||||||
msgBytes []byte
|
msgBytes []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
|
func (m *mockMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
|
||||||
m.sentMsgs = append(m.sentMsgs, sentMessage{
|
m.sentMsgs = append(m.sentMsgs, sentMessage{
|
||||||
isBroadcast: isBroadcast,
|
isBroadcast: isBroadcast,
|
||||||
toParties: toParties,
|
toParties: toParties,
|
||||||
msgBytes: msgBytes,
|
msgBytes: msgBytes,
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockMessageHandler) ReceiveMessages() <-chan *ReceivedMessage {
|
func (m *mockMessageHandler) ReceiveMessages() <-chan *ReceivedMessage {
|
||||||
return m.msgCh
|
return m.msgCh
|
||||||
}
|
}
|
||||||
|
|
||||||
// BenchmarkKeygen benchmarks the keygen operation
|
// BenchmarkKeygen benchmarks the keygen operation
|
||||||
func BenchmarkKeygen2of3(b *testing.B) {
|
func BenchmarkKeygen2of3(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, err := RunLocalKeygen(2, 3)
|
_, err := RunLocalKeygen(2, 3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatalf("Keygen failed: %v", err)
|
b.Fatalf("Keygen failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BenchmarkSigning benchmarks the signing operation
|
// BenchmarkSigning benchmarks the signing operation
|
||||||
func BenchmarkSigning2of3(b *testing.B) {
|
func BenchmarkSigning2of3(b *testing.B) {
|
||||||
// Setup: run keygen once
|
// Setup: run keygen once
|
||||||
keygenResults, err := RunLocalKeygen(2, 3)
|
keygenResults, err := RunLocalKeygen(2, 3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatalf("Keygen failed: %v", err)
|
b.Fatalf("Keygen failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
message := []byte("Benchmark signing message")
|
message := []byte("Benchmark signing message")
|
||||||
messageHash := sha256.Sum256(message)
|
messageHash := sha256.Sum256(message)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, err := RunLocalSigning(2, keygenResults, messageHash[:])
|
_, err := RunLocalSigning(2, keygenResults, messageHash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatalf("Signing failed: %v", err)
|
b.Fatalf("Signing failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,239 +1,239 @@
|
||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"math/big"
|
"math/big"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GenerateID generates a new UUID
|
// GenerateID generates a new UUID
|
||||||
func GenerateID() uuid.UUID {
|
func GenerateID() uuid.UUID {
|
||||||
return uuid.New()
|
return uuid.New()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseUUID parses a string to UUID
|
// ParseUUID parses a string to UUID
|
||||||
func ParseUUID(s string) (uuid.UUID, error) {
|
func ParseUUID(s string) (uuid.UUID, error) {
|
||||||
return uuid.Parse(s)
|
return uuid.Parse(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MustParseUUID parses a string to UUID, panics on error
|
// MustParseUUID parses a string to UUID, panics on error
|
||||||
func MustParseUUID(s string) uuid.UUID {
|
func MustParseUUID(s string) uuid.UUID {
|
||||||
id, err := uuid.Parse(s)
|
id, err := uuid.Parse(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
return id
|
return id
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidUUID checks if a string is a valid UUID
|
// IsValidUUID checks if a string is a valid UUID
|
||||||
func IsValidUUID(s string) bool {
|
func IsValidUUID(s string) bool {
|
||||||
_, err := uuid.Parse(s)
|
_, err := uuid.Parse(s)
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToJSON converts an interface to JSON bytes
|
// ToJSON converts an interface to JSON bytes
|
||||||
func ToJSON(v interface{}) ([]byte, error) {
|
func ToJSON(v interface{}) ([]byte, error) {
|
||||||
return json.Marshal(v)
|
return json.Marshal(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FromJSON converts JSON bytes to an interface
|
// FromJSON converts JSON bytes to an interface
|
||||||
func FromJSON(data []byte, v interface{}) error {
|
func FromJSON(data []byte, v interface{}) error {
|
||||||
return json.Unmarshal(data, v)
|
return json.Unmarshal(data, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NowUTC returns the current UTC time
|
// NowUTC returns the current UTC time
|
||||||
func NowUTC() time.Time {
|
func NowUTC() time.Time {
|
||||||
return time.Now().UTC()
|
return time.Now().UTC()
|
||||||
}
|
}
|
||||||
|
|
||||||
// TimePtr returns a pointer to the time
|
// TimePtr returns a pointer to the time
|
||||||
func TimePtr(t time.Time) *time.Time {
|
func TimePtr(t time.Time) *time.Time {
|
||||||
return &t
|
return &t
|
||||||
}
|
}
|
||||||
|
|
||||||
// NowPtr returns a pointer to the current time
|
// NowPtr returns a pointer to the current time
|
||||||
func NowPtr() *time.Time {
|
func NowPtr() *time.Time {
|
||||||
now := NowUTC()
|
now := NowUTC()
|
||||||
return &now
|
return &now
|
||||||
}
|
}
|
||||||
|
|
||||||
// BigIntToBytes converts a big.Int to bytes (32 bytes, left-padded)
|
// BigIntToBytes converts a big.Int to bytes (32 bytes, left-padded)
|
||||||
func BigIntToBytes(n *big.Int) []byte {
|
func BigIntToBytes(n *big.Int) []byte {
|
||||||
if n == nil {
|
if n == nil {
|
||||||
return make([]byte, 32)
|
return make([]byte, 32)
|
||||||
}
|
}
|
||||||
b := n.Bytes()
|
b := n.Bytes()
|
||||||
if len(b) > 32 {
|
if len(b) > 32 {
|
||||||
return b[:32]
|
return b[:32]
|
||||||
}
|
}
|
||||||
if len(b) < 32 {
|
if len(b) < 32 {
|
||||||
result := make([]byte, 32)
|
result := make([]byte, 32)
|
||||||
copy(result[32-len(b):], b)
|
copy(result[32-len(b):], b)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
// BytesToBigInt converts bytes to big.Int
|
// BytesToBigInt converts bytes to big.Int
|
||||||
func BytesToBigInt(b []byte) *big.Int {
|
func BytesToBigInt(b []byte) *big.Int {
|
||||||
return new(big.Int).SetBytes(b)
|
return new(big.Int).SetBytes(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringSliceContains checks if a string slice contains a value
|
// StringSliceContains checks if a string slice contains a value
|
||||||
func StringSliceContains(slice []string, value string) bool {
|
func StringSliceContains(slice []string, value string) bool {
|
||||||
for _, s := range slice {
|
for _, s := range slice {
|
||||||
if s == value {
|
if s == value {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringSliceRemove removes a value from a string slice
|
// StringSliceRemove removes a value from a string slice
|
||||||
func StringSliceRemove(slice []string, value string) []string {
|
func StringSliceRemove(slice []string, value string) []string {
|
||||||
result := make([]string, 0, len(slice))
|
result := make([]string, 0, len(slice))
|
||||||
for _, s := range slice {
|
for _, s := range slice {
|
||||||
if s != value {
|
if s != value {
|
||||||
result = append(result, s)
|
result = append(result, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// UniqueStrings returns unique strings from a slice
|
// UniqueStrings returns unique strings from a slice
|
||||||
func UniqueStrings(slice []string) []string {
|
func UniqueStrings(slice []string) []string {
|
||||||
seen := make(map[string]struct{})
|
seen := make(map[string]struct{})
|
||||||
result := make([]string, 0, len(slice))
|
result := make([]string, 0, len(slice))
|
||||||
for _, s := range slice {
|
for _, s := range slice {
|
||||||
if _, ok := seen[s]; !ok {
|
if _, ok := seen[s]; !ok {
|
||||||
seen[s] = struct{}{}
|
seen[s] = struct{}{}
|
||||||
result = append(result, s)
|
result = append(result, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// TruncateString truncates a string to max length
|
// TruncateString truncates a string to max length
|
||||||
func TruncateString(s string, maxLen int) string {
|
func TruncateString(s string, maxLen int) string {
|
||||||
if len(s) <= maxLen {
|
if len(s) <= maxLen {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
return s[:maxLen]
|
return s[:maxLen]
|
||||||
}
|
}
|
||||||
|
|
||||||
// SafeString returns an empty string if the pointer is nil
|
// SafeString returns an empty string if the pointer is nil
|
||||||
func SafeString(s *string) string {
|
func SafeString(s *string) string {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return *s
|
return *s
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringPtr returns a pointer to the string
|
// StringPtr returns a pointer to the string
|
||||||
func StringPtr(s string) *string {
|
func StringPtr(s string) *string {
|
||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
// IntPtr returns a pointer to the int
|
// IntPtr returns a pointer to the int
|
||||||
func IntPtr(i int) *int {
|
func IntPtr(i int) *int {
|
||||||
return &i
|
return &i
|
||||||
}
|
}
|
||||||
|
|
||||||
// BoolPtr returns a pointer to the bool
|
// BoolPtr returns a pointer to the bool
|
||||||
func BoolPtr(b bool) *bool {
|
func BoolPtr(b bool) *bool {
|
||||||
return &b
|
return &b
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsZero checks if a value is zero/empty
|
// IsZero checks if a value is zero/empty
|
||||||
func IsZero(v interface{}) bool {
|
func IsZero(v interface{}) bool {
|
||||||
return reflect.ValueOf(v).IsZero()
|
return reflect.ValueOf(v).IsZero()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Coalesce returns the first non-zero value
|
// Coalesce returns the first non-zero value
|
||||||
func Coalesce[T comparable](values ...T) T {
|
func Coalesce[T comparable](values ...T) T {
|
||||||
var zero T
|
var zero T
|
||||||
for _, v := range values {
|
for _, v := range values {
|
||||||
if v != zero {
|
if v != zero {
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return zero
|
return zero
|
||||||
}
|
}
|
||||||
|
|
||||||
// MapKeys returns the keys of a map
|
// MapKeys returns the keys of a map
|
||||||
func MapKeys[K comparable, V any](m map[K]V) []K {
|
func MapKeys[K comparable, V any](m map[K]V) []K {
|
||||||
keys := make([]K, 0, len(m))
|
keys := make([]K, 0, len(m))
|
||||||
for k := range m {
|
for k := range m {
|
||||||
keys = append(keys, k)
|
keys = append(keys, k)
|
||||||
}
|
}
|
||||||
return keys
|
return keys
|
||||||
}
|
}
|
||||||
|
|
||||||
// MapValues returns the values of a map
|
// MapValues returns the values of a map
|
||||||
func MapValues[K comparable, V any](m map[K]V) []V {
|
func MapValues[K comparable, V any](m map[K]V) []V {
|
||||||
values := make([]V, 0, len(m))
|
values := make([]V, 0, len(m))
|
||||||
for _, v := range m {
|
for _, v := range m {
|
||||||
values = append(values, v)
|
values = append(values, v)
|
||||||
}
|
}
|
||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
// Min returns the minimum of two values
|
// Min returns the minimum of two values
|
||||||
func Min[T ~int | ~int64 | ~float64](a, b T) T {
|
func Min[T ~int | ~int64 | ~float64](a, b T) T {
|
||||||
if a < b {
|
if a < b {
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
// Max returns the maximum of two values
|
// Max returns the maximum of two values
|
||||||
func Max[T ~int | ~int64 | ~float64](a, b T) T {
|
func Max[T ~int | ~int64 | ~float64](a, b T) T {
|
||||||
if a > b {
|
if a > b {
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clamp clamps a value between min and max
|
// Clamp clamps a value between min and max
|
||||||
func Clamp[T ~int | ~int64 | ~float64](value, min, max T) T {
|
func Clamp[T ~int | ~int64 | ~float64](value, min, max T) T {
|
||||||
if value < min {
|
if value < min {
|
||||||
return min
|
return min
|
||||||
}
|
}
|
||||||
if value > max {
|
if value > max {
|
||||||
return max
|
return max
|
||||||
}
|
}
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
// ContextWithTimeout creates a context with timeout
|
// ContextWithTimeout creates a context with timeout
|
||||||
func ContextWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) {
|
func ContextWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||||
return context.WithTimeout(context.Background(), timeout)
|
return context.WithTimeout(context.Background(), timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MaskString masks a string showing only first and last n characters
|
// MaskString masks a string showing only first and last n characters
|
||||||
func MaskString(s string, showChars int) string {
|
func MaskString(s string, showChars int) string {
|
||||||
if len(s) <= showChars*2 {
|
if len(s) <= showChars*2 {
|
||||||
return strings.Repeat("*", len(s))
|
return strings.Repeat("*", len(s))
|
||||||
}
|
}
|
||||||
return s[:showChars] + strings.Repeat("*", len(s)-showChars*2) + s[len(s)-showChars:]
|
return s[:showChars] + strings.Repeat("*", len(s)-showChars*2) + s[len(s)-showChars:]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retry executes a function with retries
|
// Retry executes a function with retries
|
||||||
func Retry(attempts int, sleep time.Duration, f func() error) error {
|
func Retry(attempts int, sleep time.Duration, f func() error) error {
|
||||||
var err error
|
var err error
|
||||||
for i := 0; i < attempts; i++ {
|
for i := 0; i < attempts; i++ {
|
||||||
if err = f(); err == nil {
|
if err = f(); err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if i < attempts-1 {
|
if i < attempts-1 {
|
||||||
time.Sleep(sleep)
|
time.Sleep(sleep)
|
||||||
sleep *= 2 // Exponential backoff
|
sleep *= 2 // Exponential backoff
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,345 +1,345 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
#
|
#
|
||||||
# Transparent Proxy Script for Gateway (192.168.1.100)
|
# Transparent Proxy Script for Gateway (192.168.1.100)
|
||||||
# Routes traffic from LAN clients (192.168.1.111) through Clash proxy
|
# Routes traffic from LAN clients (192.168.1.111) through Clash proxy
|
||||||
#
|
#
|
||||||
# Usage:
|
# Usage:
|
||||||
# ./tproxy.sh on # Enable transparent proxy
|
# ./tproxy.sh on # Enable transparent proxy
|
||||||
# ./tproxy.sh off # Disable transparent proxy
|
# ./tproxy.sh off # Disable transparent proxy
|
||||||
# ./tproxy.sh status # Check status
|
# ./tproxy.sh status # Check status
|
||||||
#
|
#
|
||||||
# Prerequisites:
|
# Prerequisites:
|
||||||
# - Clash running with allow-lan: true
|
# - Clash running with allow-lan: true
|
||||||
# - This machine is the gateway for 192.168.1.111
|
# - This machine is the gateway for 192.168.1.111
|
||||||
#
|
#
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Configuration
|
# Configuration
|
||||||
# ============================================
|
# ============================================
|
||||||
# Clash proxy ports
|
# Clash proxy ports
|
||||||
CLASH_HTTP_PORT="${CLASH_HTTP_PORT:-7890}"
|
CLASH_HTTP_PORT="${CLASH_HTTP_PORT:-7890}"
|
||||||
CLASH_SOCKS_PORT="${CLASH_SOCKS_PORT:-7891}"
|
CLASH_SOCKS_PORT="${CLASH_SOCKS_PORT:-7891}"
|
||||||
CLASH_REDIR_PORT="${CLASH_REDIR_PORT:-7892}"
|
CLASH_REDIR_PORT="${CLASH_REDIR_PORT:-7892}"
|
||||||
CLASH_TPROXY_PORT="${CLASH_TPROXY_PORT:-7893}"
|
CLASH_TPROXY_PORT="${CLASH_TPROXY_PORT:-7893}"
|
||||||
CLASH_DNS_PORT="${CLASH_DNS_PORT:-1053}"
|
CLASH_DNS_PORT="${CLASH_DNS_PORT:-1053}"
|
||||||
|
|
||||||
# Network configuration
|
# Network configuration
|
||||||
LAN_INTERFACE="${LAN_INTERFACE:-eth0}"
|
LAN_INTERFACE="${LAN_INTERFACE:-eth0}"
|
||||||
LAN_SUBNET="${LAN_SUBNET:-192.168.1.0/24}"
|
LAN_SUBNET="${LAN_SUBNET:-192.168.1.0/24}"
|
||||||
GATEWAY_IP="${GATEWAY_IP:-192.168.1.100}"
|
GATEWAY_IP="${GATEWAY_IP:-192.168.1.100}"
|
||||||
|
|
||||||
# Clients to proxy (space-separated)
|
# Clients to proxy (space-separated)
|
||||||
PROXY_CLIENTS="${PROXY_CLIENTS:-192.168.1.111}"
|
PROXY_CLIENTS="${PROXY_CLIENTS:-192.168.1.111}"
|
||||||
|
|
||||||
# Bypass destinations (don't proxy these)
|
# Bypass destinations (don't proxy these)
|
||||||
BYPASS_IPS="127.0.0.0/8 10.0.0.0/8 172.16.0.0/12 192.168.0.0/16 224.0.0.0/4 240.0.0.0/4"
|
BYPASS_IPS="127.0.0.0/8 10.0.0.0/8 172.16.0.0/12 192.168.0.0/16 224.0.0.0/4 240.0.0.0/4"
|
||||||
|
|
||||||
# iptables chain name
|
# iptables chain name
|
||||||
CHAIN_NAME="CLASH_TPROXY"
|
CHAIN_NAME="CLASH_TPROXY"
|
||||||
|
|
||||||
# Colors
|
# Colors
|
||||||
RED='\033[0;31m'
|
RED='\033[0;31m'
|
||||||
GREEN='\033[0;32m'
|
GREEN='\033[0;32m'
|
||||||
YELLOW='\033[1;33m'
|
YELLOW='\033[1;33m'
|
||||||
NC='\033[0m'
|
NC='\033[0m'
|
||||||
|
|
||||||
log_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
|
log_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
|
||||||
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
||||||
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Check Prerequisites
|
# Check Prerequisites
|
||||||
# ============================================
|
# ============================================
|
||||||
check_root() {
|
check_root() {
|
||||||
if [ "$EUID" -ne 0 ]; then
|
if [ "$EUID" -ne 0 ]; then
|
||||||
log_error "This script must be run as root"
|
log_error "This script must be run as root"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
check_clash() {
|
check_clash() {
|
||||||
# Check for any clash process (clash, clash-linux-amd64, etc.)
|
# Check for any clash process (clash, clash-linux-amd64, etc.)
|
||||||
if ! pgrep -f "clash" > /dev/null 2>&1; then
|
if ! pgrep -f "clash" > /dev/null 2>&1; then
|
||||||
log_error "Clash is not running"
|
log_error "Clash is not running"
|
||||||
log_info "Please start Clash first"
|
log_info "Please start Clash first"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Check if Clash is listening on redir port
|
# Check if Clash is listening on redir port
|
||||||
if ! ss -tlnp | grep -q ":$CLASH_REDIR_PORT"; then
|
if ! ss -tlnp | grep -q ":$CLASH_REDIR_PORT"; then
|
||||||
log_warn "Clash redir port ($CLASH_REDIR_PORT) not listening"
|
log_warn "Clash redir port ($CLASH_REDIR_PORT) not listening"
|
||||||
log_info "Make sure your Clash config has:"
|
log_info "Make sure your Clash config has:"
|
||||||
echo " redir-port: $CLASH_REDIR_PORT"
|
echo " redir-port: $CLASH_REDIR_PORT"
|
||||||
echo " allow-lan: true"
|
echo " allow-lan: true"
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Enable Transparent Proxy
|
# Enable Transparent Proxy
|
||||||
# ============================================
|
# ============================================
|
||||||
enable_tproxy() {
|
enable_tproxy() {
|
||||||
check_root
|
check_root
|
||||||
check_clash
|
check_clash
|
||||||
|
|
||||||
log_info "Enabling transparent proxy..."
|
log_info "Enabling transparent proxy..."
|
||||||
|
|
||||||
# Enable IP forwarding
|
# Enable IP forwarding
|
||||||
log_info "Enabling IP forwarding..."
|
log_info "Enabling IP forwarding..."
|
||||||
echo 1 > /proc/sys/net/ipv4/ip_forward
|
echo 1 > /proc/sys/net/ipv4/ip_forward
|
||||||
sysctl -w net.ipv4.ip_forward=1 > /dev/null
|
sysctl -w net.ipv4.ip_forward=1 > /dev/null
|
||||||
|
|
||||||
# Make IP forwarding persistent
|
# Make IP forwarding persistent
|
||||||
if ! grep -q "net.ipv4.ip_forward=1" /etc/sysctl.conf; then
|
if ! grep -q "net.ipv4.ip_forward=1" /etc/sysctl.conf; then
|
||||||
echo "net.ipv4.ip_forward=1" >> /etc/sysctl.conf
|
echo "net.ipv4.ip_forward=1" >> /etc/sysctl.conf
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Create NAT chain for transparent proxy
|
# Create NAT chain for transparent proxy
|
||||||
log_info "Creating iptables rules..."
|
log_info "Creating iptables rules..."
|
||||||
|
|
||||||
# Remove existing rules if any
|
# Remove existing rules if any
|
||||||
iptables -t nat -D PREROUTING -j $CHAIN_NAME 2>/dev/null || true
|
iptables -t nat -D PREROUTING -j $CHAIN_NAME 2>/dev/null || true
|
||||||
iptables -t nat -F $CHAIN_NAME 2>/dev/null || true
|
iptables -t nat -F $CHAIN_NAME 2>/dev/null || true
|
||||||
iptables -t nat -X $CHAIN_NAME 2>/dev/null || true
|
iptables -t nat -X $CHAIN_NAME 2>/dev/null || true
|
||||||
|
|
||||||
# Create new chain
|
# Create new chain
|
||||||
iptables -t nat -N $CHAIN_NAME
|
iptables -t nat -N $CHAIN_NAME
|
||||||
|
|
||||||
# Bypass local and private networks
|
# Bypass local and private networks
|
||||||
for ip in $BYPASS_IPS; do
|
for ip in $BYPASS_IPS; do
|
||||||
iptables -t nat -A $CHAIN_NAME -d $ip -j RETURN
|
iptables -t nat -A $CHAIN_NAME -d $ip -j RETURN
|
||||||
done
|
done
|
||||||
|
|
||||||
# Bypass traffic to this gateway itself
|
# Bypass traffic to this gateway itself
|
||||||
iptables -t nat -A $CHAIN_NAME -d $GATEWAY_IP -j RETURN
|
iptables -t nat -A $CHAIN_NAME -d $GATEWAY_IP -j RETURN
|
||||||
|
|
||||||
# Only proxy traffic from specified clients
|
# Only proxy traffic from specified clients
|
||||||
for client in $PROXY_CLIENTS; do
|
for client in $PROXY_CLIENTS; do
|
||||||
log_info "Adding proxy rule for client: $client"
|
log_info "Adding proxy rule for client: $client"
|
||||||
# Redirect HTTP/HTTPS traffic to Clash redir port
|
# Redirect HTTP/HTTPS traffic to Clash redir port
|
||||||
iptables -t nat -A $CHAIN_NAME -s $client -p tcp --dport 80 -j REDIRECT --to-ports $CLASH_REDIR_PORT
|
iptables -t nat -A $CHAIN_NAME -s $client -p tcp --dport 80 -j REDIRECT --to-ports $CLASH_REDIR_PORT
|
||||||
iptables -t nat -A $CHAIN_NAME -s $client -p tcp --dport 443 -j REDIRECT --to-ports $CLASH_REDIR_PORT
|
iptables -t nat -A $CHAIN_NAME -s $client -p tcp --dport 443 -j REDIRECT --to-ports $CLASH_REDIR_PORT
|
||||||
# Redirect all other TCP traffic
|
# Redirect all other TCP traffic
|
||||||
iptables -t nat -A $CHAIN_NAME -s $client -p tcp -j REDIRECT --to-ports $CLASH_REDIR_PORT
|
iptables -t nat -A $CHAIN_NAME -s $client -p tcp -j REDIRECT --to-ports $CLASH_REDIR_PORT
|
||||||
done
|
done
|
||||||
|
|
||||||
# Apply the chain to PREROUTING
|
# Apply the chain to PREROUTING
|
||||||
iptables -t nat -A PREROUTING -j $CHAIN_NAME
|
iptables -t nat -A PREROUTING -j $CHAIN_NAME
|
||||||
|
|
||||||
# Setup DNS redirect (optional - redirect DNS to Clash DNS)
|
# Setup DNS redirect (optional - redirect DNS to Clash DNS)
|
||||||
if ss -ulnp | grep -q ":$CLASH_DNS_PORT"; then
|
if ss -ulnp | grep -q ":$CLASH_DNS_PORT"; then
|
||||||
log_info "Setting up DNS redirect to Clash DNS..."
|
log_info "Setting up DNS redirect to Clash DNS..."
|
||||||
for client in $PROXY_CLIENTS; do
|
for client in $PROXY_CLIENTS; do
|
||||||
iptables -t nat -A PREROUTING -s $client -p udp --dport 53 -j REDIRECT --to-ports $CLASH_DNS_PORT
|
iptables -t nat -A PREROUTING -s $client -p udp --dport 53 -j REDIRECT --to-ports $CLASH_DNS_PORT
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Ensure MASQUERADE for forwarded traffic
|
# Ensure MASQUERADE for forwarded traffic
|
||||||
iptables -t nat -A POSTROUTING -s $LAN_SUBNET -o $LAN_INTERFACE -j MASQUERADE 2>/dev/null || true
|
iptables -t nat -A POSTROUTING -s $LAN_SUBNET -o $LAN_INTERFACE -j MASQUERADE 2>/dev/null || true
|
||||||
|
|
||||||
log_info "Transparent proxy enabled!"
|
log_info "Transparent proxy enabled!"
|
||||||
log_info ""
|
log_info ""
|
||||||
log_info "Proxied clients: $PROXY_CLIENTS"
|
log_info "Proxied clients: $PROXY_CLIENTS"
|
||||||
log_info "Clash redir port: $CLASH_REDIR_PORT"
|
log_info "Clash redir port: $CLASH_REDIR_PORT"
|
||||||
log_info ""
|
log_info ""
|
||||||
log_info "Test from client (192.168.1.111):"
|
log_info "Test from client (192.168.1.111):"
|
||||||
log_info " curl -I https://www.google.com"
|
log_info " curl -I https://www.google.com"
|
||||||
}
|
}
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Disable Transparent Proxy
|
# Disable Transparent Proxy
|
||||||
# ============================================
|
# ============================================
|
||||||
disable_tproxy() {
|
disable_tproxy() {
|
||||||
check_root
|
check_root
|
||||||
|
|
||||||
log_info "Disabling transparent proxy..."
|
log_info "Disabling transparent proxy..."
|
||||||
|
|
||||||
# Remove DNS redirect rules
|
# Remove DNS redirect rules
|
||||||
for client in $PROXY_CLIENTS; do
|
for client in $PROXY_CLIENTS; do
|
||||||
iptables -t nat -D PREROUTING -s $client -p udp --dport 53 -j REDIRECT --to-ports $CLASH_DNS_PORT 2>/dev/null || true
|
iptables -t nat -D PREROUTING -s $client -p udp --dport 53 -j REDIRECT --to-ports $CLASH_DNS_PORT 2>/dev/null || true
|
||||||
done
|
done
|
||||||
|
|
||||||
# Remove the chain from PREROUTING
|
# Remove the chain from PREROUTING
|
||||||
iptables -t nat -D PREROUTING -j $CHAIN_NAME 2>/dev/null || true
|
iptables -t nat -D PREROUTING -j $CHAIN_NAME 2>/dev/null || true
|
||||||
|
|
||||||
# Flush and delete the chain
|
# Flush and delete the chain
|
||||||
iptables -t nat -F $CHAIN_NAME 2>/dev/null || true
|
iptables -t nat -F $CHAIN_NAME 2>/dev/null || true
|
||||||
iptables -t nat -X $CHAIN_NAME 2>/dev/null || true
|
iptables -t nat -X $CHAIN_NAME 2>/dev/null || true
|
||||||
|
|
||||||
log_info "Transparent proxy disabled!"
|
log_info "Transparent proxy disabled!"
|
||||||
log_info ""
|
log_info ""
|
||||||
log_info "Clients will now access internet directly (through NAT only)"
|
log_info "Clients will now access internet directly (through NAT only)"
|
||||||
}
|
}
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Check Status
|
# Check Status
|
||||||
# ============================================
|
# ============================================
|
||||||
show_status() {
|
show_status() {
|
||||||
echo ""
|
echo ""
|
||||||
echo "============================================"
|
echo "============================================"
|
||||||
echo "Transparent Proxy Status"
|
echo "Transparent Proxy Status"
|
||||||
echo "============================================"
|
echo "============================================"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Check IP forwarding
|
# Check IP forwarding
|
||||||
echo "IP Forwarding:"
|
echo "IP Forwarding:"
|
||||||
if [ "$(cat /proc/sys/net/ipv4/ip_forward)" = "1" ]; then
|
if [ "$(cat /proc/sys/net/ipv4/ip_forward)" = "1" ]; then
|
||||||
echo -e " Status: ${GREEN}Enabled${NC}"
|
echo -e " Status: ${GREEN}Enabled${NC}"
|
||||||
else
|
else
|
||||||
echo -e " Status: ${RED}Disabled${NC}"
|
echo -e " Status: ${RED}Disabled${NC}"
|
||||||
fi
|
fi
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Check Clash
|
# Check Clash
|
||||||
echo "Clash Process:"
|
echo "Clash Process:"
|
||||||
if pgrep -f "clash" > /dev/null 2>&1; then
|
if pgrep -f "clash" > /dev/null 2>&1; then
|
||||||
echo -e " Status: ${GREEN}Running${NC}"
|
echo -e " Status: ${GREEN}Running${NC}"
|
||||||
echo " PID: $(pgrep -f clash | head -1)"
|
echo " PID: $(pgrep -f clash | head -1)"
|
||||||
else
|
else
|
||||||
echo -e " Status: ${RED}Not Running${NC}"
|
echo -e " Status: ${RED}Not Running${NC}"
|
||||||
fi
|
fi
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Check Clash ports
|
# Check Clash ports
|
||||||
echo "Clash Ports:"
|
echo "Clash Ports:"
|
||||||
echo -n " HTTP ($CLASH_HTTP_PORT): "
|
echo -n " HTTP ($CLASH_HTTP_PORT): "
|
||||||
ss -tlnp | grep -q ":$CLASH_HTTP_PORT" && echo -e "${GREEN}Listening${NC}" || echo -e "${RED}Not Listening${NC}"
|
ss -tlnp | grep -q ":$CLASH_HTTP_PORT" && echo -e "${GREEN}Listening${NC}" || echo -e "${RED}Not Listening${NC}"
|
||||||
echo -n " SOCKS ($CLASH_SOCKS_PORT): "
|
echo -n " SOCKS ($CLASH_SOCKS_PORT): "
|
||||||
ss -tlnp | grep -q ":$CLASH_SOCKS_PORT" && echo -e "${GREEN}Listening${NC}" || echo -e "${RED}Not Listening${NC}"
|
ss -tlnp | grep -q ":$CLASH_SOCKS_PORT" && echo -e "${GREEN}Listening${NC}" || echo -e "${RED}Not Listening${NC}"
|
||||||
echo -n " Redir ($CLASH_REDIR_PORT): "
|
echo -n " Redir ($CLASH_REDIR_PORT): "
|
||||||
ss -tlnp | grep -q ":$CLASH_REDIR_PORT" && echo -e "${GREEN}Listening${NC}" || echo -e "${RED}Not Listening${NC}"
|
ss -tlnp | grep -q ":$CLASH_REDIR_PORT" && echo -e "${GREEN}Listening${NC}" || echo -e "${RED}Not Listening${NC}"
|
||||||
echo -n " DNS ($CLASH_DNS_PORT): "
|
echo -n " DNS ($CLASH_DNS_PORT): "
|
||||||
ss -ulnp | grep -q ":$CLASH_DNS_PORT" && echo -e "${GREEN}Listening${NC}" || echo -e "${RED}Not Listening${NC}"
|
ss -ulnp | grep -q ":$CLASH_DNS_PORT" && echo -e "${GREEN}Listening${NC}" || echo -e "${RED}Not Listening${NC}"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Check iptables rules
|
# Check iptables rules
|
||||||
echo "iptables Transparent Proxy Chain:"
|
echo "iptables Transparent Proxy Chain:"
|
||||||
if iptables -t nat -L $CHAIN_NAME > /dev/null 2>&1; then
|
if iptables -t nat -L $CHAIN_NAME > /dev/null 2>&1; then
|
||||||
echo -e " Status: ${GREEN}Active${NC}"
|
echo -e " Status: ${GREEN}Active${NC}"
|
||||||
echo " Rules:"
|
echo " Rules:"
|
||||||
iptables -t nat -L $CHAIN_NAME -n --line-numbers 2>/dev/null | head -20
|
iptables -t nat -L $CHAIN_NAME -n --line-numbers 2>/dev/null | head -20
|
||||||
else
|
else
|
||||||
echo -e " Status: ${YELLOW}Not Active${NC}"
|
echo -e " Status: ${YELLOW}Not Active${NC}"
|
||||||
fi
|
fi
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Check PREROUTING
|
# Check PREROUTING
|
||||||
echo "PREROUTING Chain (first 10 rules):"
|
echo "PREROUTING Chain (first 10 rules):"
|
||||||
iptables -t nat -L PREROUTING -n --line-numbers | head -12
|
iptables -t nat -L PREROUTING -n --line-numbers | head -12
|
||||||
echo ""
|
echo ""
|
||||||
}
|
}
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Test Proxy from Client
|
# Test Proxy from Client
|
||||||
# ============================================
|
# ============================================
|
||||||
test_proxy() {
|
test_proxy() {
|
||||||
echo ""
|
echo ""
|
||||||
echo "============================================"
|
echo "============================================"
|
||||||
echo "Proxy Test Instructions"
|
echo "Proxy Test Instructions"
|
||||||
echo "============================================"
|
echo "============================================"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Run these commands on 192.168.1.111 to test:"
|
echo "Run these commands on 192.168.1.111 to test:"
|
||||||
echo ""
|
echo ""
|
||||||
echo "1. Test Google (requires proxy):"
|
echo "1. Test Google (requires proxy):"
|
||||||
echo " curl -I --connect-timeout 5 https://www.google.com"
|
echo " curl -I --connect-timeout 5 https://www.google.com"
|
||||||
echo ""
|
echo ""
|
||||||
echo "2. Test external IP:"
|
echo "2. Test external IP:"
|
||||||
echo " curl -s https://ipinfo.io/ip"
|
echo " curl -s https://ipinfo.io/ip"
|
||||||
echo ""
|
echo ""
|
||||||
echo "3. Test Docker Hub:"
|
echo "3. Test Docker Hub:"
|
||||||
echo " curl -I --connect-timeout 5 https://registry-1.docker.io/v2/"
|
echo " curl -I --connect-timeout 5 https://registry-1.docker.io/v2/"
|
||||||
echo ""
|
echo ""
|
||||||
echo "4. Test GitHub:"
|
echo "4. Test GitHub:"
|
||||||
echo " curl -I --connect-timeout 5 https://github.com"
|
echo " curl -I --connect-timeout 5 https://github.com"
|
||||||
echo ""
|
echo ""
|
||||||
}
|
}
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Show Required Clash Configuration
|
# Show Required Clash Configuration
|
||||||
# ============================================
|
# ============================================
|
||||||
show_clash_config() {
|
show_clash_config() {
|
||||||
echo ""
|
echo ""
|
||||||
echo "============================================"
|
echo "============================================"
|
||||||
echo "Required Clash Configuration"
|
echo "Required Clash Configuration"
|
||||||
echo "============================================"
|
echo "============================================"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Add these settings to your Clash config.yaml:"
|
echo "Add these settings to your Clash config.yaml:"
|
||||||
echo ""
|
echo ""
|
||||||
cat << 'EOF'
|
cat << 'EOF'
|
||||||
# Enable LAN access
|
# Enable LAN access
|
||||||
allow-lan: true
|
allow-lan: true
|
||||||
bind-address: "*"
|
bind-address: "*"
|
||||||
|
|
||||||
# Proxy ports
|
# Proxy ports
|
||||||
port: 7890 # HTTP proxy
|
port: 7890 # HTTP proxy
|
||||||
socks-port: 7891 # SOCKS5 proxy
|
socks-port: 7891 # SOCKS5 proxy
|
||||||
redir-port: 7892 # Transparent proxy (Linux only)
|
redir-port: 7892 # Transparent proxy (Linux only)
|
||||||
tproxy-port: 7893 # TProxy port (optional)
|
tproxy-port: 7893 # TProxy port (optional)
|
||||||
|
|
||||||
# DNS settings (optional but recommended)
|
# DNS settings (optional but recommended)
|
||||||
dns:
|
dns:
|
||||||
enable: true
|
enable: true
|
||||||
listen: 0.0.0.0:1053
|
listen: 0.0.0.0:1053
|
||||||
enhanced-mode: fake-ip
|
enhanced-mode: fake-ip
|
||||||
fake-ip-range: 198.18.0.1/16
|
fake-ip-range: 198.18.0.1/16
|
||||||
nameserver:
|
nameserver:
|
||||||
- 223.5.5.5
|
- 223.5.5.5
|
||||||
- 119.29.29.29
|
- 119.29.29.29
|
||||||
fallback:
|
fallback:
|
||||||
- 8.8.8.8
|
- 8.8.8.8
|
||||||
- 1.1.1.1
|
- 1.1.1.1
|
||||||
EOF
|
EOF
|
||||||
echo ""
|
echo ""
|
||||||
echo "After modifying, restart Clash:"
|
echo "After modifying, restart Clash:"
|
||||||
echo " systemctl restart clash"
|
echo " systemctl restart clash"
|
||||||
echo " # or"
|
echo " # or"
|
||||||
echo " killall clash && clash -d /path/to/config &"
|
echo " killall clash && clash -d /path/to/config &"
|
||||||
echo ""
|
echo ""
|
||||||
}
|
}
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Main
|
# Main
|
||||||
# ============================================
|
# ============================================
|
||||||
case "${1:-}" in
|
case "${1:-}" in
|
||||||
on|enable|start)
|
on|enable|start)
|
||||||
enable_tproxy
|
enable_tproxy
|
||||||
;;
|
;;
|
||||||
off|disable|stop)
|
off|disable|stop)
|
||||||
disable_tproxy
|
disable_tproxy
|
||||||
;;
|
;;
|
||||||
status)
|
status)
|
||||||
show_status
|
show_status
|
||||||
;;
|
;;
|
||||||
test)
|
test)
|
||||||
test_proxy
|
test_proxy
|
||||||
;;
|
;;
|
||||||
config)
|
config)
|
||||||
show_clash_config
|
show_clash_config
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Transparent Proxy Manager for Clash"
|
echo "Transparent Proxy Manager for Clash"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Usage: $0 {on|off|status|test|config}"
|
echo "Usage: $0 {on|off|status|test|config}"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Commands:"
|
echo "Commands:"
|
||||||
echo " on - Enable transparent proxy for LAN clients"
|
echo " on - Enable transparent proxy for LAN clients"
|
||||||
echo " off - Disable transparent proxy"
|
echo " off - Disable transparent proxy"
|
||||||
echo " status - Show current status"
|
echo " status - Show current status"
|
||||||
echo " test - Show test commands for clients"
|
echo " test - Show test commands for clients"
|
||||||
echo " config - Show required Clash configuration"
|
echo " config - Show required Clash configuration"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Environment Variables:"
|
echo "Environment Variables:"
|
||||||
echo " CLASH_REDIR_PORT - Clash redir port (default: 7892)"
|
echo " CLASH_REDIR_PORT - Clash redir port (default: 7892)"
|
||||||
echo " CLASH_DNS_PORT - Clash DNS port (default: 1053)"
|
echo " CLASH_DNS_PORT - Clash DNS port (default: 1053)"
|
||||||
echo " LAN_INTERFACE - LAN interface (default: eth0)"
|
echo " LAN_INTERFACE - LAN interface (default: eth0)"
|
||||||
echo " PROXY_CLIENTS - Space-separated client IPs (default: 192.168.1.111)"
|
echo " PROXY_CLIENTS - Space-separated client IPs (default: 192.168.1.111)"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Example:"
|
echo "Example:"
|
||||||
echo " sudo $0 on # Enable with defaults"
|
echo " sudo $0 on # Enable with defaults"
|
||||||
echo " sudo PROXY_CLIENTS='192.168.1.111 192.168.1.112' $0 on # Multiple clients"
|
echo " sudo PROXY_CLIENTS='192.168.1.111 192.168.1.112' $0 on # Multiple clients"
|
||||||
echo " sudo $0 off # Disable"
|
echo " sudo $0 off # Disable"
|
||||||
echo ""
|
echo ""
|
||||||
exit 1
|
exit 1
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
|
|
|
||||||
|
|
@ -1,38 +1,38 @@
|
||||||
# Build stage
|
# Build stage
|
||||||
FROM golang:1.21-alpine AS builder
|
FROM golang:1.21-alpine AS builder
|
||||||
|
|
||||||
RUN apk add --no-cache git ca-certificates
|
RUN apk add --no-cache git ca-certificates
|
||||||
|
|
||||||
# Set Go proxy (can be overridden with --build-arg GOPROXY=...)
|
# Set Go proxy (can be overridden with --build-arg GOPROXY=...)
|
||||||
ARG GOPROXY=https://proxy.golang.org,direct
|
ARG GOPROXY=https://proxy.golang.org,direct
|
||||||
ENV GOPROXY=${GOPROXY}
|
ENV GOPROXY=${GOPROXY}
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
COPY go.mod go.sum ./
|
COPY go.mod go.sum ./
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
||||||
-ldflags="-w -s" \
|
-ldflags="-w -s" \
|
||||||
-o /bin/account-service \
|
-o /bin/account-service \
|
||||||
./services/account/cmd/server
|
./services/account/cmd/server
|
||||||
|
|
||||||
# Final stage
|
# Final stage
|
||||||
FROM alpine:3.18
|
FROM alpine:3.18
|
||||||
|
|
||||||
RUN apk --no-cache add ca-certificates curl
|
RUN apk --no-cache add ca-certificates curl
|
||||||
RUN adduser -D -s /bin/sh mpc
|
RUN adduser -D -s /bin/sh mpc
|
||||||
|
|
||||||
COPY --from=builder /bin/account-service /bin/account-service
|
COPY --from=builder /bin/account-service /bin/account-service
|
||||||
|
|
||||||
USER mpc
|
USER mpc
|
||||||
|
|
||||||
EXPOSE 50051 8080
|
EXPOSE 50051 8080
|
||||||
|
|
||||||
# Health check
|
# Health check
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||||
CMD curl -sf http://localhost:8080/health || exit 1
|
CMD curl -sf http://localhost:8080/health || exit 1
|
||||||
|
|
||||||
ENTRYPOINT ["/bin/account-service"]
|
ENTRYPOINT ["/bin/account-service"]
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,486 +1,486 @@
|
||||||
package http
|
package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
||||||
"github.com/rwadurian/mpc-system/services/account/application/use_cases"
|
"github.com/rwadurian/mpc-system/services/account/application/use_cases"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccountHTTPHandler handles HTTP requests for accounts
|
// AccountHTTPHandler handles HTTP requests for accounts
|
||||||
type AccountHTTPHandler struct {
|
type AccountHTTPHandler struct {
|
||||||
createAccountUC *use_cases.CreateAccountUseCase
|
createAccountUC *use_cases.CreateAccountUseCase
|
||||||
getAccountUC *use_cases.GetAccountUseCase
|
getAccountUC *use_cases.GetAccountUseCase
|
||||||
updateAccountUC *use_cases.UpdateAccountUseCase
|
updateAccountUC *use_cases.UpdateAccountUseCase
|
||||||
listAccountsUC *use_cases.ListAccountsUseCase
|
listAccountsUC *use_cases.ListAccountsUseCase
|
||||||
getAccountSharesUC *use_cases.GetAccountSharesUseCase
|
getAccountSharesUC *use_cases.GetAccountSharesUseCase
|
||||||
deactivateShareUC *use_cases.DeactivateShareUseCase
|
deactivateShareUC *use_cases.DeactivateShareUseCase
|
||||||
loginUC *use_cases.LoginUseCase
|
loginUC *use_cases.LoginUseCase
|
||||||
refreshTokenUC *use_cases.RefreshTokenUseCase
|
refreshTokenUC *use_cases.RefreshTokenUseCase
|
||||||
generateChallengeUC *use_cases.GenerateChallengeUseCase
|
generateChallengeUC *use_cases.GenerateChallengeUseCase
|
||||||
initiateRecoveryUC *use_cases.InitiateRecoveryUseCase
|
initiateRecoveryUC *use_cases.InitiateRecoveryUseCase
|
||||||
completeRecoveryUC *use_cases.CompleteRecoveryUseCase
|
completeRecoveryUC *use_cases.CompleteRecoveryUseCase
|
||||||
getRecoveryStatusUC *use_cases.GetRecoveryStatusUseCase
|
getRecoveryStatusUC *use_cases.GetRecoveryStatusUseCase
|
||||||
cancelRecoveryUC *use_cases.CancelRecoveryUseCase
|
cancelRecoveryUC *use_cases.CancelRecoveryUseCase
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountHTTPHandler creates a new AccountHTTPHandler
|
// NewAccountHTTPHandler creates a new AccountHTTPHandler
|
||||||
func NewAccountHTTPHandler(
|
func NewAccountHTTPHandler(
|
||||||
createAccountUC *use_cases.CreateAccountUseCase,
|
createAccountUC *use_cases.CreateAccountUseCase,
|
||||||
getAccountUC *use_cases.GetAccountUseCase,
|
getAccountUC *use_cases.GetAccountUseCase,
|
||||||
updateAccountUC *use_cases.UpdateAccountUseCase,
|
updateAccountUC *use_cases.UpdateAccountUseCase,
|
||||||
listAccountsUC *use_cases.ListAccountsUseCase,
|
listAccountsUC *use_cases.ListAccountsUseCase,
|
||||||
getAccountSharesUC *use_cases.GetAccountSharesUseCase,
|
getAccountSharesUC *use_cases.GetAccountSharesUseCase,
|
||||||
deactivateShareUC *use_cases.DeactivateShareUseCase,
|
deactivateShareUC *use_cases.DeactivateShareUseCase,
|
||||||
loginUC *use_cases.LoginUseCase,
|
loginUC *use_cases.LoginUseCase,
|
||||||
refreshTokenUC *use_cases.RefreshTokenUseCase,
|
refreshTokenUC *use_cases.RefreshTokenUseCase,
|
||||||
generateChallengeUC *use_cases.GenerateChallengeUseCase,
|
generateChallengeUC *use_cases.GenerateChallengeUseCase,
|
||||||
initiateRecoveryUC *use_cases.InitiateRecoveryUseCase,
|
initiateRecoveryUC *use_cases.InitiateRecoveryUseCase,
|
||||||
completeRecoveryUC *use_cases.CompleteRecoveryUseCase,
|
completeRecoveryUC *use_cases.CompleteRecoveryUseCase,
|
||||||
getRecoveryStatusUC *use_cases.GetRecoveryStatusUseCase,
|
getRecoveryStatusUC *use_cases.GetRecoveryStatusUseCase,
|
||||||
cancelRecoveryUC *use_cases.CancelRecoveryUseCase,
|
cancelRecoveryUC *use_cases.CancelRecoveryUseCase,
|
||||||
) *AccountHTTPHandler {
|
) *AccountHTTPHandler {
|
||||||
return &AccountHTTPHandler{
|
return &AccountHTTPHandler{
|
||||||
createAccountUC: createAccountUC,
|
createAccountUC: createAccountUC,
|
||||||
getAccountUC: getAccountUC,
|
getAccountUC: getAccountUC,
|
||||||
updateAccountUC: updateAccountUC,
|
updateAccountUC: updateAccountUC,
|
||||||
listAccountsUC: listAccountsUC,
|
listAccountsUC: listAccountsUC,
|
||||||
getAccountSharesUC: getAccountSharesUC,
|
getAccountSharesUC: getAccountSharesUC,
|
||||||
deactivateShareUC: deactivateShareUC,
|
deactivateShareUC: deactivateShareUC,
|
||||||
loginUC: loginUC,
|
loginUC: loginUC,
|
||||||
refreshTokenUC: refreshTokenUC,
|
refreshTokenUC: refreshTokenUC,
|
||||||
generateChallengeUC: generateChallengeUC,
|
generateChallengeUC: generateChallengeUC,
|
||||||
initiateRecoveryUC: initiateRecoveryUC,
|
initiateRecoveryUC: initiateRecoveryUC,
|
||||||
completeRecoveryUC: completeRecoveryUC,
|
completeRecoveryUC: completeRecoveryUC,
|
||||||
getRecoveryStatusUC: getRecoveryStatusUC,
|
getRecoveryStatusUC: getRecoveryStatusUC,
|
||||||
cancelRecoveryUC: cancelRecoveryUC,
|
cancelRecoveryUC: cancelRecoveryUC,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterRoutes registers HTTP routes
|
// RegisterRoutes registers HTTP routes
|
||||||
func (h *AccountHTTPHandler) RegisterRoutes(router *gin.RouterGroup) {
|
func (h *AccountHTTPHandler) RegisterRoutes(router *gin.RouterGroup) {
|
||||||
accounts := router.Group("/accounts")
|
accounts := router.Group("/accounts")
|
||||||
{
|
{
|
||||||
accounts.POST("", h.CreateAccount)
|
accounts.POST("", h.CreateAccount)
|
||||||
accounts.GET("", h.ListAccounts)
|
accounts.GET("", h.ListAccounts)
|
||||||
accounts.GET("/:id", h.GetAccount)
|
accounts.GET("/:id", h.GetAccount)
|
||||||
accounts.PUT("/:id", h.UpdateAccount)
|
accounts.PUT("/:id", h.UpdateAccount)
|
||||||
accounts.GET("/:id/shares", h.GetAccountShares)
|
accounts.GET("/:id/shares", h.GetAccountShares)
|
||||||
accounts.DELETE("/:id/shares/:shareId", h.DeactivateShare)
|
accounts.DELETE("/:id/shares/:shareId", h.DeactivateShare)
|
||||||
}
|
}
|
||||||
|
|
||||||
auth := router.Group("/auth")
|
auth := router.Group("/auth")
|
||||||
{
|
{
|
||||||
auth.POST("/challenge", h.GenerateChallenge)
|
auth.POST("/challenge", h.GenerateChallenge)
|
||||||
auth.POST("/login", h.Login)
|
auth.POST("/login", h.Login)
|
||||||
auth.POST("/refresh", h.RefreshToken)
|
auth.POST("/refresh", h.RefreshToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
recovery := router.Group("/recovery")
|
recovery := router.Group("/recovery")
|
||||||
{
|
{
|
||||||
recovery.POST("", h.InitiateRecovery)
|
recovery.POST("", h.InitiateRecovery)
|
||||||
recovery.GET("/:id", h.GetRecoveryStatus)
|
recovery.GET("/:id", h.GetRecoveryStatus)
|
||||||
recovery.POST("/:id/complete", h.CompleteRecovery)
|
recovery.POST("/:id/complete", h.CompleteRecovery)
|
||||||
recovery.POST("/:id/cancel", h.CancelRecovery)
|
recovery.POST("/:id/cancel", h.CancelRecovery)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccountRequest represents the request for creating an account
|
// CreateAccountRequest represents the request for creating an account
|
||||||
type CreateAccountRequest struct {
|
type CreateAccountRequest struct {
|
||||||
Username string `json:"username" binding:"required"`
|
Username string `json:"username" binding:"required"`
|
||||||
Email string `json:"email" binding:"required,email"`
|
Email string `json:"email" binding:"required,email"`
|
||||||
Phone *string `json:"phone"`
|
Phone *string `json:"phone"`
|
||||||
PublicKey string `json:"publicKey" binding:"required"`
|
PublicKey string `json:"publicKey" binding:"required"`
|
||||||
KeygenSessionID string `json:"keygenSessionId" binding:"required"`
|
KeygenSessionID string `json:"keygenSessionId" binding:"required"`
|
||||||
ThresholdN int `json:"thresholdN" binding:"required,min=1"`
|
ThresholdN int `json:"thresholdN" binding:"required,min=1"`
|
||||||
ThresholdT int `json:"thresholdT" binding:"required,min=1"`
|
ThresholdT int `json:"thresholdT" binding:"required,min=1"`
|
||||||
Shares []ShareInput `json:"shares" binding:"required,min=1"`
|
Shares []ShareInput `json:"shares" binding:"required,min=1"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShareInput represents a share in the request
|
// ShareInput represents a share in the request
|
||||||
type ShareInput struct {
|
type ShareInput struct {
|
||||||
ShareType string `json:"shareType" binding:"required"`
|
ShareType string `json:"shareType" binding:"required"`
|
||||||
PartyID string `json:"partyId" binding:"required"`
|
PartyID string `json:"partyId" binding:"required"`
|
||||||
PartyIndex int `json:"partyIndex" binding:"required,min=0"`
|
PartyIndex int `json:"partyIndex" binding:"required,min=0"`
|
||||||
DeviceType *string `json:"deviceType"`
|
DeviceType *string `json:"deviceType"`
|
||||||
DeviceID *string `json:"deviceId"`
|
DeviceID *string `json:"deviceId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccount handles account creation
|
// CreateAccount handles account creation
|
||||||
func (h *AccountHTTPHandler) CreateAccount(c *gin.Context) {
|
func (h *AccountHTTPHandler) CreateAccount(c *gin.Context) {
|
||||||
var req CreateAccountRequest
|
var req CreateAccountRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
keygenSessionID, err := uuid.Parse(req.KeygenSessionID)
|
keygenSessionID, err := uuid.Parse(req.KeygenSessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid keygen session ID"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid keygen session ID"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
shares := make([]ports.ShareInput, len(req.Shares))
|
shares := make([]ports.ShareInput, len(req.Shares))
|
||||||
for i, s := range req.Shares {
|
for i, s := range req.Shares {
|
||||||
shares[i] = ports.ShareInput{
|
shares[i] = ports.ShareInput{
|
||||||
ShareType: value_objects.ShareType(s.ShareType),
|
ShareType: value_objects.ShareType(s.ShareType),
|
||||||
PartyID: s.PartyID,
|
PartyID: s.PartyID,
|
||||||
PartyIndex: s.PartyIndex,
|
PartyIndex: s.PartyIndex,
|
||||||
DeviceType: s.DeviceType,
|
DeviceType: s.DeviceType,
|
||||||
DeviceID: s.DeviceID,
|
DeviceID: s.DeviceID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
output, err := h.createAccountUC.Execute(c.Request.Context(), ports.CreateAccountInput{
|
output, err := h.createAccountUC.Execute(c.Request.Context(), ports.CreateAccountInput{
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Email: req.Email,
|
Email: req.Email,
|
||||||
Phone: req.Phone,
|
Phone: req.Phone,
|
||||||
PublicKey: []byte(req.PublicKey),
|
PublicKey: []byte(req.PublicKey),
|
||||||
KeygenSessionID: keygenSessionID,
|
KeygenSessionID: keygenSessionID,
|
||||||
ThresholdN: req.ThresholdN,
|
ThresholdN: req.ThresholdN,
|
||||||
ThresholdT: req.ThresholdT,
|
ThresholdT: req.ThresholdT,
|
||||||
Shares: shares,
|
Shares: shares,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, gin.H{
|
c.JSON(http.StatusCreated, gin.H{
|
||||||
"account": output.Account,
|
"account": output.Account,
|
||||||
"shares": output.Shares,
|
"shares": output.Shares,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccount handles getting account by ID
|
// GetAccount handles getting account by ID
|
||||||
func (h *AccountHTTPHandler) GetAccount(c *gin.Context) {
|
func (h *AccountHTTPHandler) GetAccount(c *gin.Context) {
|
||||||
idStr := c.Param("id")
|
idStr := c.Param("id")
|
||||||
accountID, err := value_objects.AccountIDFromString(idStr)
|
accountID, err := value_objects.AccountIDFromString(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid account ID"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid account ID"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
output, err := h.getAccountUC.Execute(c.Request.Context(), ports.GetAccountInput{
|
output, err := h.getAccountUC.Execute(c.Request.Context(), ports.GetAccountInput{
|
||||||
AccountID: &accountID,
|
AccountID: &accountID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"account": output.Account,
|
"account": output.Account,
|
||||||
"shares": output.Shares,
|
"shares": output.Shares,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccountRequest represents the request for updating an account
|
// UpdateAccountRequest represents the request for updating an account
|
||||||
type UpdateAccountRequest struct {
|
type UpdateAccountRequest struct {
|
||||||
Phone *string `json:"phone"`
|
Phone *string `json:"phone"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccount handles account updates
|
// UpdateAccount handles account updates
|
||||||
func (h *AccountHTTPHandler) UpdateAccount(c *gin.Context) {
|
func (h *AccountHTTPHandler) UpdateAccount(c *gin.Context) {
|
||||||
idStr := c.Param("id")
|
idStr := c.Param("id")
|
||||||
accountID, err := value_objects.AccountIDFromString(idStr)
|
accountID, err := value_objects.AccountIDFromString(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid account ID"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid account ID"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req UpdateAccountRequest
|
var req UpdateAccountRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
output, err := h.updateAccountUC.Execute(c.Request.Context(), ports.UpdateAccountInput{
|
output, err := h.updateAccountUC.Execute(c.Request.Context(), ports.UpdateAccountInput{
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
Phone: req.Phone,
|
Phone: req.Phone,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, output.Account)
|
c.JSON(http.StatusOK, output.Account)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListAccounts handles listing accounts
|
// ListAccounts handles listing accounts
|
||||||
func (h *AccountHTTPHandler) ListAccounts(c *gin.Context) {
|
func (h *AccountHTTPHandler) ListAccounts(c *gin.Context) {
|
||||||
var offset, limit int
|
var offset, limit int
|
||||||
if o := c.Query("offset"); o != "" {
|
if o := c.Query("offset"); o != "" {
|
||||||
// Parse offset
|
// Parse offset
|
||||||
}
|
}
|
||||||
if l := c.Query("limit"); l != "" {
|
if l := c.Query("limit"); l != "" {
|
||||||
// Parse limit
|
// Parse limit
|
||||||
}
|
}
|
||||||
|
|
||||||
output, err := h.listAccountsUC.Execute(c.Request.Context(), use_cases.ListAccountsInput{
|
output, err := h.listAccountsUC.Execute(c.Request.Context(), use_cases.ListAccountsInput{
|
||||||
Offset: offset,
|
Offset: offset,
|
||||||
Limit: limit,
|
Limit: limit,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"accounts": output.Accounts,
|
"accounts": output.Accounts,
|
||||||
"total": output.Total,
|
"total": output.Total,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountShares handles getting account shares
|
// GetAccountShares handles getting account shares
|
||||||
func (h *AccountHTTPHandler) GetAccountShares(c *gin.Context) {
|
func (h *AccountHTTPHandler) GetAccountShares(c *gin.Context) {
|
||||||
idStr := c.Param("id")
|
idStr := c.Param("id")
|
||||||
accountID, err := value_objects.AccountIDFromString(idStr)
|
accountID, err := value_objects.AccountIDFromString(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid account ID"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid account ID"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
output, err := h.getAccountSharesUC.Execute(c.Request.Context(), accountID)
|
output, err := h.getAccountSharesUC.Execute(c.Request.Context(), accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"shares": output.Shares,
|
"shares": output.Shares,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeactivateShare handles share deactivation
|
// DeactivateShare handles share deactivation
|
||||||
func (h *AccountHTTPHandler) DeactivateShare(c *gin.Context) {
|
func (h *AccountHTTPHandler) DeactivateShare(c *gin.Context) {
|
||||||
idStr := c.Param("id")
|
idStr := c.Param("id")
|
||||||
accountID, err := value_objects.AccountIDFromString(idStr)
|
accountID, err := value_objects.AccountIDFromString(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid account ID"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid account ID"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
shareID := c.Param("shareId")
|
shareID := c.Param("shareId")
|
||||||
|
|
||||||
err = h.deactivateShareUC.Execute(c.Request.Context(), ports.DeactivateShareInput{
|
err = h.deactivateShareUC.Execute(c.Request.Context(), ports.DeactivateShareInput{
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
ShareID: shareID,
|
ShareID: shareID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "share deactivated"})
|
c.JSON(http.StatusOK, gin.H{"message": "share deactivated"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateChallengeRequest represents the request for generating a challenge
|
// GenerateChallengeRequest represents the request for generating a challenge
|
||||||
type GenerateChallengeRequest struct {
|
type GenerateChallengeRequest struct {
|
||||||
Username string `json:"username" binding:"required"`
|
Username string `json:"username" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateChallenge handles challenge generation
|
// GenerateChallenge handles challenge generation
|
||||||
func (h *AccountHTTPHandler) GenerateChallenge(c *gin.Context) {
|
func (h *AccountHTTPHandler) GenerateChallenge(c *gin.Context) {
|
||||||
var req GenerateChallengeRequest
|
var req GenerateChallengeRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
output, err := h.generateChallengeUC.Execute(c.Request.Context(), use_cases.GenerateChallengeInput{
|
output, err := h.generateChallengeUC.Execute(c.Request.Context(), use_cases.GenerateChallengeInput{
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"challengeId": output.ChallengeID,
|
"challengeId": output.ChallengeID,
|
||||||
"challenge": output.Challenge,
|
"challenge": output.Challenge,
|
||||||
"expiresAt": output.ExpiresAt,
|
"expiresAt": output.ExpiresAt,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginRequest represents the request for login
|
// LoginRequest represents the request for login
|
||||||
type LoginRequest struct {
|
type LoginRequest struct {
|
||||||
Username string `json:"username" binding:"required"`
|
Username string `json:"username" binding:"required"`
|
||||||
Challenge string `json:"challenge" binding:"required"`
|
Challenge string `json:"challenge" binding:"required"`
|
||||||
Signature string `json:"signature" binding:"required"`
|
Signature string `json:"signature" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Login handles user login
|
// Login handles user login
|
||||||
func (h *AccountHTTPHandler) Login(c *gin.Context) {
|
func (h *AccountHTTPHandler) Login(c *gin.Context) {
|
||||||
var req LoginRequest
|
var req LoginRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
output, err := h.loginUC.Execute(c.Request.Context(), ports.LoginInput{
|
output, err := h.loginUC.Execute(c.Request.Context(), ports.LoginInput{
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Challenge: []byte(req.Challenge),
|
Challenge: []byte(req.Challenge),
|
||||||
Signature: []byte(req.Signature),
|
Signature: []byte(req.Signature),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"account": output.Account,
|
"account": output.Account,
|
||||||
"accessToken": output.AccessToken,
|
"accessToken": output.AccessToken,
|
||||||
"refreshToken": output.RefreshToken,
|
"refreshToken": output.RefreshToken,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshTokenRequest represents the request for refreshing tokens
|
// RefreshTokenRequest represents the request for refreshing tokens
|
||||||
type RefreshTokenRequest struct {
|
type RefreshTokenRequest struct {
|
||||||
RefreshToken string `json:"refreshToken" binding:"required"`
|
RefreshToken string `json:"refreshToken" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshToken handles token refresh
|
// RefreshToken handles token refresh
|
||||||
func (h *AccountHTTPHandler) RefreshToken(c *gin.Context) {
|
func (h *AccountHTTPHandler) RefreshToken(c *gin.Context) {
|
||||||
var req RefreshTokenRequest
|
var req RefreshTokenRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
output, err := h.refreshTokenUC.Execute(c.Request.Context(), use_cases.RefreshTokenInput{
|
output, err := h.refreshTokenUC.Execute(c.Request.Context(), use_cases.RefreshTokenInput{
|
||||||
RefreshToken: req.RefreshToken,
|
RefreshToken: req.RefreshToken,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"accessToken": output.AccessToken,
|
"accessToken": output.AccessToken,
|
||||||
"refreshToken": output.RefreshToken,
|
"refreshToken": output.RefreshToken,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitiateRecoveryRequest represents the request for initiating recovery
|
// InitiateRecoveryRequest represents the request for initiating recovery
|
||||||
type InitiateRecoveryRequest struct {
|
type InitiateRecoveryRequest struct {
|
||||||
AccountID string `json:"accountId" binding:"required"`
|
AccountID string `json:"accountId" binding:"required"`
|
||||||
RecoveryType string `json:"recoveryType" binding:"required"`
|
RecoveryType string `json:"recoveryType" binding:"required"`
|
||||||
OldShareType *string `json:"oldShareType"`
|
OldShareType *string `json:"oldShareType"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitiateRecovery handles recovery initiation
|
// InitiateRecovery handles recovery initiation
|
||||||
func (h *AccountHTTPHandler) InitiateRecovery(c *gin.Context) {
|
func (h *AccountHTTPHandler) InitiateRecovery(c *gin.Context) {
|
||||||
var req InitiateRecoveryRequest
|
var req InitiateRecoveryRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accountID, err := value_objects.AccountIDFromString(req.AccountID)
|
accountID, err := value_objects.AccountIDFromString(req.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid account ID"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid account ID"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
input := ports.InitiateRecoveryInput{
|
input := ports.InitiateRecoveryInput{
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
RecoveryType: value_objects.RecoveryType(req.RecoveryType),
|
RecoveryType: value_objects.RecoveryType(req.RecoveryType),
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.OldShareType != nil {
|
if req.OldShareType != nil {
|
||||||
st := value_objects.ShareType(*req.OldShareType)
|
st := value_objects.ShareType(*req.OldShareType)
|
||||||
input.OldShareType = &st
|
input.OldShareType = &st
|
||||||
}
|
}
|
||||||
|
|
||||||
output, err := h.initiateRecoveryUC.Execute(c.Request.Context(), input)
|
output, err := h.initiateRecoveryUC.Execute(c.Request.Context(), input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, gin.H{
|
c.JSON(http.StatusCreated, gin.H{
|
||||||
"recoverySession": output.RecoverySession,
|
"recoverySession": output.RecoverySession,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRecoveryStatus handles getting recovery status
|
// GetRecoveryStatus handles getting recovery status
|
||||||
func (h *AccountHTTPHandler) GetRecoveryStatus(c *gin.Context) {
|
func (h *AccountHTTPHandler) GetRecoveryStatus(c *gin.Context) {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
|
|
||||||
output, err := h.getRecoveryStatusUC.Execute(c.Request.Context(), use_cases.GetRecoveryStatusInput{
|
output, err := h.getRecoveryStatusUC.Execute(c.Request.Context(), use_cases.GetRecoveryStatusInput{
|
||||||
RecoverySessionID: id,
|
RecoverySessionID: id,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, output.RecoverySession)
|
c.JSON(http.StatusOK, output.RecoverySession)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CompleteRecoveryRequest represents the request for completing recovery
|
// CompleteRecoveryRequest represents the request for completing recovery
|
||||||
type CompleteRecoveryRequest struct {
|
type CompleteRecoveryRequest struct {
|
||||||
NewPublicKey string `json:"newPublicKey" binding:"required"`
|
NewPublicKey string `json:"newPublicKey" binding:"required"`
|
||||||
NewKeygenSessionID string `json:"newKeygenSessionId" binding:"required"`
|
NewKeygenSessionID string `json:"newKeygenSessionId" binding:"required"`
|
||||||
NewShares []ShareInput `json:"newShares" binding:"required,min=1"`
|
NewShares []ShareInput `json:"newShares" binding:"required,min=1"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CompleteRecovery handles recovery completion
|
// CompleteRecovery handles recovery completion
|
||||||
func (h *AccountHTTPHandler) CompleteRecovery(c *gin.Context) {
|
func (h *AccountHTTPHandler) CompleteRecovery(c *gin.Context) {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
|
|
||||||
var req CompleteRecoveryRequest
|
var req CompleteRecoveryRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newKeygenSessionID, err := uuid.Parse(req.NewKeygenSessionID)
|
newKeygenSessionID, err := uuid.Parse(req.NewKeygenSessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid keygen session ID"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid keygen session ID"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newShares := make([]ports.ShareInput, len(req.NewShares))
|
newShares := make([]ports.ShareInput, len(req.NewShares))
|
||||||
for i, s := range req.NewShares {
|
for i, s := range req.NewShares {
|
||||||
newShares[i] = ports.ShareInput{
|
newShares[i] = ports.ShareInput{
|
||||||
ShareType: value_objects.ShareType(s.ShareType),
|
ShareType: value_objects.ShareType(s.ShareType),
|
||||||
PartyID: s.PartyID,
|
PartyID: s.PartyID,
|
||||||
PartyIndex: s.PartyIndex,
|
PartyIndex: s.PartyIndex,
|
||||||
DeviceType: s.DeviceType,
|
DeviceType: s.DeviceType,
|
||||||
DeviceID: s.DeviceID,
|
DeviceID: s.DeviceID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
output, err := h.completeRecoveryUC.Execute(c.Request.Context(), ports.CompleteRecoveryInput{
|
output, err := h.completeRecoveryUC.Execute(c.Request.Context(), ports.CompleteRecoveryInput{
|
||||||
RecoverySessionID: id,
|
RecoverySessionID: id,
|
||||||
NewPublicKey: []byte(req.NewPublicKey),
|
NewPublicKey: []byte(req.NewPublicKey),
|
||||||
NewKeygenSessionID: newKeygenSessionID,
|
NewKeygenSessionID: newKeygenSessionID,
|
||||||
NewShares: newShares,
|
NewShares: newShares,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, output.Account)
|
c.JSON(http.StatusOK, output.Account)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CancelRecovery handles recovery cancellation
|
// CancelRecovery handles recovery cancellation
|
||||||
func (h *AccountHTTPHandler) CancelRecovery(c *gin.Context) {
|
func (h *AccountHTTPHandler) CancelRecovery(c *gin.Context) {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
|
|
||||||
err := h.cancelRecoveryUC.Execute(c.Request.Context(), use_cases.CancelRecoveryInput{
|
err := h.cancelRecoveryUC.Execute(c.Request.Context(), use_cases.CancelRecoveryInput{
|
||||||
RecoverySessionID: id,
|
RecoverySessionID: id,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "recovery cancelled"})
|
c.JSON(http.StatusOK, gin.H{"message": "recovery cancelled"})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,54 +1,54 @@
|
||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/rwadurian/mpc-system/pkg/jwt"
|
"github.com/rwadurian/mpc-system/pkg/jwt"
|
||||||
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TokenServiceAdapter implements TokenService using JWT
|
// TokenServiceAdapter implements TokenService using JWT
|
||||||
type TokenServiceAdapter struct {
|
type TokenServiceAdapter struct {
|
||||||
jwtService *jwt.JWTService
|
jwtService *jwt.JWTService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTokenServiceAdapter creates a new TokenServiceAdapter
|
// NewTokenServiceAdapter creates a new TokenServiceAdapter
|
||||||
func NewTokenServiceAdapter(jwtService *jwt.JWTService) ports.TokenService {
|
func NewTokenServiceAdapter(jwtService *jwt.JWTService) ports.TokenService {
|
||||||
return &TokenServiceAdapter{jwtService: jwtService}
|
return &TokenServiceAdapter{jwtService: jwtService}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateAccessToken generates an access token for an account
|
// GenerateAccessToken generates an access token for an account
|
||||||
func (t *TokenServiceAdapter) GenerateAccessToken(accountID, username string) (string, error) {
|
func (t *TokenServiceAdapter) GenerateAccessToken(accountID, username string) (string, error) {
|
||||||
return t.jwtService.GenerateAccessToken(accountID, username)
|
return t.jwtService.GenerateAccessToken(accountID, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateRefreshToken generates a refresh token for an account
|
// GenerateRefreshToken generates a refresh token for an account
|
||||||
func (t *TokenServiceAdapter) GenerateRefreshToken(accountID string) (string, error) {
|
func (t *TokenServiceAdapter) GenerateRefreshToken(accountID string) (string, error) {
|
||||||
return t.jwtService.GenerateRefreshToken(accountID)
|
return t.jwtService.GenerateRefreshToken(accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateAccessToken validates an access token
|
// ValidateAccessToken validates an access token
|
||||||
func (t *TokenServiceAdapter) ValidateAccessToken(token string) (claims map[string]interface{}, err error) {
|
func (t *TokenServiceAdapter) ValidateAccessToken(token string) (claims map[string]interface{}, err error) {
|
||||||
accessClaims, err := t.jwtService.ValidateAccessToken(token)
|
accessClaims, err := t.jwtService.ValidateAccessToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return map[string]interface{}{
|
return map[string]interface{}{
|
||||||
"subject": accessClaims.Subject,
|
"subject": accessClaims.Subject,
|
||||||
"username": accessClaims.Username,
|
"username": accessClaims.Username,
|
||||||
"issuer": accessClaims.Issuer,
|
"issuer": accessClaims.Issuer,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateRefreshToken validates a refresh token
|
// ValidateRefreshToken validates a refresh token
|
||||||
func (t *TokenServiceAdapter) ValidateRefreshToken(token string) (accountID string, err error) {
|
func (t *TokenServiceAdapter) ValidateRefreshToken(token string) (accountID string, err error) {
|
||||||
claims, err := t.jwtService.ValidateRefreshToken(token)
|
claims, err := t.jwtService.ValidateRefreshToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return claims.Subject, nil
|
return claims.Subject, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshAccessToken refreshes an access token using a refresh token
|
// RefreshAccessToken refreshes an access token using a refresh token
|
||||||
func (t *TokenServiceAdapter) RefreshAccessToken(refreshToken string) (accessToken string, err error) {
|
func (t *TokenServiceAdapter) RefreshAccessToken(refreshToken string) (accessToken string, err error) {
|
||||||
return t.jwtService.RefreshAccessToken(refreshToken)
|
return t.jwtService.RefreshAccessToken(refreshToken)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,316 +1,316 @@
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccountPostgresRepo implements AccountRepository using PostgreSQL
|
// AccountPostgresRepo implements AccountRepository using PostgreSQL
|
||||||
type AccountPostgresRepo struct {
|
type AccountPostgresRepo struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountPostgresRepo creates a new AccountPostgresRepo
|
// NewAccountPostgresRepo creates a new AccountPostgresRepo
|
||||||
func NewAccountPostgresRepo(db *sql.DB) repositories.AccountRepository {
|
func NewAccountPostgresRepo(db *sql.DB) repositories.AccountRepository {
|
||||||
return &AccountPostgresRepo{db: db}
|
return &AccountPostgresRepo{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create creates a new account
|
// Create creates a new account
|
||||||
func (r *AccountPostgresRepo) Create(ctx context.Context, account *entities.Account) error {
|
func (r *AccountPostgresRepo) Create(ctx context.Context, account *entities.Account) error {
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO accounts (id, username, email, phone, public_key, keygen_session_id,
|
INSERT INTO accounts (id, username, email, phone, public_key, keygen_session_id,
|
||||||
threshold_n, threshold_t, status, created_at, updated_at, last_login_at)
|
threshold_n, threshold_t, status, created_at, updated_at, last_login_at)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||||
`
|
`
|
||||||
|
|
||||||
_, err := r.db.ExecContext(ctx, query,
|
_, err := r.db.ExecContext(ctx, query,
|
||||||
account.ID.UUID(),
|
account.ID.UUID(),
|
||||||
account.Username,
|
account.Username,
|
||||||
account.Email,
|
account.Email,
|
||||||
account.Phone,
|
account.Phone,
|
||||||
account.PublicKey,
|
account.PublicKey,
|
||||||
account.KeygenSessionID,
|
account.KeygenSessionID,
|
||||||
account.ThresholdN,
|
account.ThresholdN,
|
||||||
account.ThresholdT,
|
account.ThresholdT,
|
||||||
account.Status.String(),
|
account.Status.String(),
|
||||||
account.CreatedAt,
|
account.CreatedAt,
|
||||||
account.UpdatedAt,
|
account.UpdatedAt,
|
||||||
account.LastLoginAt,
|
account.LastLoginAt,
|
||||||
)
|
)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID retrieves an account by ID
|
// GetByID retrieves an account by ID
|
||||||
func (r *AccountPostgresRepo) GetByID(ctx context.Context, id value_objects.AccountID) (*entities.Account, error) {
|
func (r *AccountPostgresRepo) GetByID(ctx context.Context, id value_objects.AccountID) (*entities.Account, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, username, email, phone, public_key, keygen_session_id,
|
SELECT id, username, email, phone, public_key, keygen_session_id,
|
||||||
threshold_n, threshold_t, status, created_at, updated_at, last_login_at
|
threshold_n, threshold_t, status, created_at, updated_at, last_login_at
|
||||||
FROM accounts
|
FROM accounts
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
return r.scanAccount(r.db.QueryRowContext(ctx, query, id.UUID()))
|
return r.scanAccount(r.db.QueryRowContext(ctx, query, id.UUID()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByUsername retrieves an account by username
|
// GetByUsername retrieves an account by username
|
||||||
func (r *AccountPostgresRepo) GetByUsername(ctx context.Context, username string) (*entities.Account, error) {
|
func (r *AccountPostgresRepo) GetByUsername(ctx context.Context, username string) (*entities.Account, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, username, email, phone, public_key, keygen_session_id,
|
SELECT id, username, email, phone, public_key, keygen_session_id,
|
||||||
threshold_n, threshold_t, status, created_at, updated_at, last_login_at
|
threshold_n, threshold_t, status, created_at, updated_at, last_login_at
|
||||||
FROM accounts
|
FROM accounts
|
||||||
WHERE username = $1
|
WHERE username = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
return r.scanAccount(r.db.QueryRowContext(ctx, query, username))
|
return r.scanAccount(r.db.QueryRowContext(ctx, query, username))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByEmail retrieves an account by email
|
// GetByEmail retrieves an account by email
|
||||||
func (r *AccountPostgresRepo) GetByEmail(ctx context.Context, email string) (*entities.Account, error) {
|
func (r *AccountPostgresRepo) GetByEmail(ctx context.Context, email string) (*entities.Account, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, username, email, phone, public_key, keygen_session_id,
|
SELECT id, username, email, phone, public_key, keygen_session_id,
|
||||||
threshold_n, threshold_t, status, created_at, updated_at, last_login_at
|
threshold_n, threshold_t, status, created_at, updated_at, last_login_at
|
||||||
FROM accounts
|
FROM accounts
|
||||||
WHERE email = $1
|
WHERE email = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
return r.scanAccount(r.db.QueryRowContext(ctx, query, email))
|
return r.scanAccount(r.db.QueryRowContext(ctx, query, email))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByPublicKey retrieves an account by public key
|
// GetByPublicKey retrieves an account by public key
|
||||||
func (r *AccountPostgresRepo) GetByPublicKey(ctx context.Context, publicKey []byte) (*entities.Account, error) {
|
func (r *AccountPostgresRepo) GetByPublicKey(ctx context.Context, publicKey []byte) (*entities.Account, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, username, email, phone, public_key, keygen_session_id,
|
SELECT id, username, email, phone, public_key, keygen_session_id,
|
||||||
threshold_n, threshold_t, status, created_at, updated_at, last_login_at
|
threshold_n, threshold_t, status, created_at, updated_at, last_login_at
|
||||||
FROM accounts
|
FROM accounts
|
||||||
WHERE public_key = $1
|
WHERE public_key = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
return r.scanAccount(r.db.QueryRowContext(ctx, query, publicKey))
|
return r.scanAccount(r.db.QueryRowContext(ctx, query, publicKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update updates an existing account
|
// Update updates an existing account
|
||||||
func (r *AccountPostgresRepo) Update(ctx context.Context, account *entities.Account) error {
|
func (r *AccountPostgresRepo) Update(ctx context.Context, account *entities.Account) error {
|
||||||
query := `
|
query := `
|
||||||
UPDATE accounts
|
UPDATE accounts
|
||||||
SET username = $2, email = $3, phone = $4, public_key = $5, keygen_session_id = $6,
|
SET username = $2, email = $3, phone = $4, public_key = $5, keygen_session_id = $6,
|
||||||
threshold_n = $7, threshold_t = $8, status = $9, updated_at = $10, last_login_at = $11
|
threshold_n = $7, threshold_t = $8, status = $9, updated_at = $10, last_login_at = $11
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
result, err := r.db.ExecContext(ctx, query,
|
result, err := r.db.ExecContext(ctx, query,
|
||||||
account.ID.UUID(),
|
account.ID.UUID(),
|
||||||
account.Username,
|
account.Username,
|
||||||
account.Email,
|
account.Email,
|
||||||
account.Phone,
|
account.Phone,
|
||||||
account.PublicKey,
|
account.PublicKey,
|
||||||
account.KeygenSessionID,
|
account.KeygenSessionID,
|
||||||
account.ThresholdN,
|
account.ThresholdN,
|
||||||
account.ThresholdT,
|
account.ThresholdT,
|
||||||
account.Status.String(),
|
account.Status.String(),
|
||||||
account.UpdatedAt,
|
account.UpdatedAt,
|
||||||
account.LastLoginAt,
|
account.LastLoginAt,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsAffected, err := result.RowsAffected()
|
rowsAffected, err := result.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
if rowsAffected == 0 {
|
||||||
return entities.ErrAccountNotFound
|
return entities.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete deletes an account
|
// Delete deletes an account
|
||||||
func (r *AccountPostgresRepo) Delete(ctx context.Context, id value_objects.AccountID) error {
|
func (r *AccountPostgresRepo) Delete(ctx context.Context, id value_objects.AccountID) error {
|
||||||
query := `DELETE FROM accounts WHERE id = $1`
|
query := `DELETE FROM accounts WHERE id = $1`
|
||||||
|
|
||||||
result, err := r.db.ExecContext(ctx, query, id.UUID())
|
result, err := r.db.ExecContext(ctx, query, id.UUID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsAffected, err := result.RowsAffected()
|
rowsAffected, err := result.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
if rowsAffected == 0 {
|
||||||
return entities.ErrAccountNotFound
|
return entities.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExistsByUsername checks if username exists
|
// ExistsByUsername checks if username exists
|
||||||
func (r *AccountPostgresRepo) ExistsByUsername(ctx context.Context, username string) (bool, error) {
|
func (r *AccountPostgresRepo) ExistsByUsername(ctx context.Context, username string) (bool, error) {
|
||||||
query := `SELECT EXISTS(SELECT 1 FROM accounts WHERE username = $1)`
|
query := `SELECT EXISTS(SELECT 1 FROM accounts WHERE username = $1)`
|
||||||
|
|
||||||
var exists bool
|
var exists bool
|
||||||
err := r.db.QueryRowContext(ctx, query, username).Scan(&exists)
|
err := r.db.QueryRowContext(ctx, query, username).Scan(&exists)
|
||||||
return exists, err
|
return exists, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExistsByEmail checks if email exists
|
// ExistsByEmail checks if email exists
|
||||||
func (r *AccountPostgresRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
func (r *AccountPostgresRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||||
query := `SELECT EXISTS(SELECT 1 FROM accounts WHERE email = $1)`
|
query := `SELECT EXISTS(SELECT 1 FROM accounts WHERE email = $1)`
|
||||||
|
|
||||||
var exists bool
|
var exists bool
|
||||||
err := r.db.QueryRowContext(ctx, query, email).Scan(&exists)
|
err := r.db.QueryRowContext(ctx, query, email).Scan(&exists)
|
||||||
return exists, err
|
return exists, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// List lists accounts with pagination
|
// List lists accounts with pagination
|
||||||
func (r *AccountPostgresRepo) List(ctx context.Context, offset, limit int) ([]*entities.Account, error) {
|
func (r *AccountPostgresRepo) List(ctx context.Context, offset, limit int) ([]*entities.Account, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, username, email, phone, public_key, keygen_session_id,
|
SELECT id, username, email, phone, public_key, keygen_session_id,
|
||||||
threshold_n, threshold_t, status, created_at, updated_at, last_login_at
|
threshold_n, threshold_t, status, created_at, updated_at, last_login_at
|
||||||
FROM accounts
|
FROM accounts
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
LIMIT $1 OFFSET $2
|
LIMIT $1 OFFSET $2
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := r.db.QueryContext(ctx, query, limit, offset)
|
rows, err := r.db.QueryContext(ctx, query, limit, offset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
var accounts []*entities.Account
|
var accounts []*entities.Account
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
account, err := r.scanAccountFromRows(rows)
|
account, err := r.scanAccountFromRows(rows)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
accounts = append(accounts, account)
|
accounts = append(accounts, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
return accounts, rows.Err()
|
return accounts, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count returns the total number of accounts
|
// Count returns the total number of accounts
|
||||||
func (r *AccountPostgresRepo) Count(ctx context.Context) (int64, error) {
|
func (r *AccountPostgresRepo) Count(ctx context.Context) (int64, error) {
|
||||||
query := `SELECT COUNT(*) FROM accounts`
|
query := `SELECT COUNT(*) FROM accounts`
|
||||||
|
|
||||||
var count int64
|
var count int64
|
||||||
err := r.db.QueryRowContext(ctx, query).Scan(&count)
|
err := r.db.QueryRowContext(ctx, query).Scan(&count)
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// scanAccount scans a single account row
|
// scanAccount scans a single account row
|
||||||
func (r *AccountPostgresRepo) scanAccount(row *sql.Row) (*entities.Account, error) {
|
func (r *AccountPostgresRepo) scanAccount(row *sql.Row) (*entities.Account, error) {
|
||||||
var (
|
var (
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
username string
|
username string
|
||||||
email sql.NullString
|
email sql.NullString
|
||||||
phone sql.NullString
|
phone sql.NullString
|
||||||
publicKey []byte
|
publicKey []byte
|
||||||
keygenSessionID uuid.UUID
|
keygenSessionID uuid.UUID
|
||||||
thresholdN int
|
thresholdN int
|
||||||
thresholdT int
|
thresholdT int
|
||||||
status string
|
status string
|
||||||
account entities.Account
|
account entities.Account
|
||||||
)
|
)
|
||||||
|
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&id,
|
&id,
|
||||||
&username,
|
&username,
|
||||||
&email,
|
&email,
|
||||||
&phone,
|
&phone,
|
||||||
&publicKey,
|
&publicKey,
|
||||||
&keygenSessionID,
|
&keygenSessionID,
|
||||||
&thresholdN,
|
&thresholdN,
|
||||||
&thresholdT,
|
&thresholdT,
|
||||||
&status,
|
&status,
|
||||||
&account.CreatedAt,
|
&account.CreatedAt,
|
||||||
&account.UpdatedAt,
|
&account.UpdatedAt,
|
||||||
&account.LastLoginAt,
|
&account.LastLoginAt,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, entities.ErrAccountNotFound
|
return nil, entities.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
account.ID = value_objects.AccountIDFromUUID(id)
|
account.ID = value_objects.AccountIDFromUUID(id)
|
||||||
account.Username = username
|
account.Username = username
|
||||||
if email.Valid {
|
if email.Valid {
|
||||||
account.Email = &email.String
|
account.Email = &email.String
|
||||||
}
|
}
|
||||||
if phone.Valid {
|
if phone.Valid {
|
||||||
account.Phone = &phone.String
|
account.Phone = &phone.String
|
||||||
}
|
}
|
||||||
account.PublicKey = publicKey
|
account.PublicKey = publicKey
|
||||||
account.KeygenSessionID = keygenSessionID
|
account.KeygenSessionID = keygenSessionID
|
||||||
account.ThresholdN = thresholdN
|
account.ThresholdN = thresholdN
|
||||||
account.ThresholdT = thresholdT
|
account.ThresholdT = thresholdT
|
||||||
account.Status = value_objects.AccountStatus(status)
|
account.Status = value_objects.AccountStatus(status)
|
||||||
|
|
||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// scanAccountFromRows scans account from rows
|
// scanAccountFromRows scans account from rows
|
||||||
func (r *AccountPostgresRepo) scanAccountFromRows(rows *sql.Rows) (*entities.Account, error) {
|
func (r *AccountPostgresRepo) scanAccountFromRows(rows *sql.Rows) (*entities.Account, error) {
|
||||||
var (
|
var (
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
username string
|
username string
|
||||||
email sql.NullString
|
email sql.NullString
|
||||||
phone sql.NullString
|
phone sql.NullString
|
||||||
publicKey []byte
|
publicKey []byte
|
||||||
keygenSessionID uuid.UUID
|
keygenSessionID uuid.UUID
|
||||||
thresholdN int
|
thresholdN int
|
||||||
thresholdT int
|
thresholdT int
|
||||||
status string
|
status string
|
||||||
account entities.Account
|
account entities.Account
|
||||||
)
|
)
|
||||||
|
|
||||||
err := rows.Scan(
|
err := rows.Scan(
|
||||||
&id,
|
&id,
|
||||||
&username,
|
&username,
|
||||||
&email,
|
&email,
|
||||||
&phone,
|
&phone,
|
||||||
&publicKey,
|
&publicKey,
|
||||||
&keygenSessionID,
|
&keygenSessionID,
|
||||||
&thresholdN,
|
&thresholdN,
|
||||||
&thresholdT,
|
&thresholdT,
|
||||||
&status,
|
&status,
|
||||||
&account.CreatedAt,
|
&account.CreatedAt,
|
||||||
&account.UpdatedAt,
|
&account.UpdatedAt,
|
||||||
&account.LastLoginAt,
|
&account.LastLoginAt,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
account.ID = value_objects.AccountIDFromUUID(id)
|
account.ID = value_objects.AccountIDFromUUID(id)
|
||||||
account.Username = username
|
account.Username = username
|
||||||
if email.Valid {
|
if email.Valid {
|
||||||
account.Email = &email.String
|
account.Email = &email.String
|
||||||
}
|
}
|
||||||
if phone.Valid {
|
if phone.Valid {
|
||||||
account.Phone = &phone.String
|
account.Phone = &phone.String
|
||||||
}
|
}
|
||||||
account.PublicKey = publicKey
|
account.PublicKey = publicKey
|
||||||
account.KeygenSessionID = keygenSessionID
|
account.KeygenSessionID = keygenSessionID
|
||||||
account.ThresholdN = thresholdN
|
account.ThresholdN = thresholdN
|
||||||
account.ThresholdT = thresholdT
|
account.ThresholdT = thresholdT
|
||||||
account.Status = value_objects.AccountStatus(status)
|
account.Status = value_objects.AccountStatus(status)
|
||||||
|
|
||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,266 +1,266 @@
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RecoverySessionPostgresRepo implements RecoverySessionRepository using PostgreSQL
|
// RecoverySessionPostgresRepo implements RecoverySessionRepository using PostgreSQL
|
||||||
type RecoverySessionPostgresRepo struct {
|
type RecoverySessionPostgresRepo struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRecoverySessionPostgresRepo creates a new RecoverySessionPostgresRepo
|
// NewRecoverySessionPostgresRepo creates a new RecoverySessionPostgresRepo
|
||||||
func NewRecoverySessionPostgresRepo(db *sql.DB) repositories.RecoverySessionRepository {
|
func NewRecoverySessionPostgresRepo(db *sql.DB) repositories.RecoverySessionRepository {
|
||||||
return &RecoverySessionPostgresRepo{db: db}
|
return &RecoverySessionPostgresRepo{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create creates a new recovery session
|
// Create creates a new recovery session
|
||||||
func (r *RecoverySessionPostgresRepo) Create(ctx context.Context, session *entities.RecoverySession) error {
|
func (r *RecoverySessionPostgresRepo) Create(ctx context.Context, session *entities.RecoverySession) error {
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO account_recovery_sessions (id, account_id, recovery_type, old_share_type,
|
INSERT INTO account_recovery_sessions (id, account_id, recovery_type, old_share_type,
|
||||||
new_keygen_session_id, status, requested_at, completed_at)
|
new_keygen_session_id, status, requested_at, completed_at)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
`
|
`
|
||||||
|
|
||||||
var oldShareType *string
|
var oldShareType *string
|
||||||
if session.OldShareType != nil {
|
if session.OldShareType != nil {
|
||||||
s := session.OldShareType.String()
|
s := session.OldShareType.String()
|
||||||
oldShareType = &s
|
oldShareType = &s
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := r.db.ExecContext(ctx, query,
|
_, err := r.db.ExecContext(ctx, query,
|
||||||
session.ID,
|
session.ID,
|
||||||
session.AccountID.UUID(),
|
session.AccountID.UUID(),
|
||||||
session.RecoveryType.String(),
|
session.RecoveryType.String(),
|
||||||
oldShareType,
|
oldShareType,
|
||||||
session.NewKeygenSessionID,
|
session.NewKeygenSessionID,
|
||||||
session.Status.String(),
|
session.Status.String(),
|
||||||
session.RequestedAt,
|
session.RequestedAt,
|
||||||
session.CompletedAt,
|
session.CompletedAt,
|
||||||
)
|
)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID retrieves a recovery session by ID
|
// GetByID retrieves a recovery session by ID
|
||||||
func (r *RecoverySessionPostgresRepo) GetByID(ctx context.Context, id string) (*entities.RecoverySession, error) {
|
func (r *RecoverySessionPostgresRepo) GetByID(ctx context.Context, id string) (*entities.RecoverySession, error) {
|
||||||
sessionID, err := uuid.Parse(id)
|
sessionID, err := uuid.Parse(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, entities.ErrRecoveryNotFound
|
return nil, entities.ErrRecoveryNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
SELECT id, account_id, recovery_type, old_share_type,
|
SELECT id, account_id, recovery_type, old_share_type,
|
||||||
new_keygen_session_id, status, requested_at, completed_at
|
new_keygen_session_id, status, requested_at, completed_at
|
||||||
FROM account_recovery_sessions
|
FROM account_recovery_sessions
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
return r.scanSession(r.db.QueryRowContext(ctx, query, sessionID))
|
return r.scanSession(r.db.QueryRowContext(ctx, query, sessionID))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByAccountID retrieves recovery sessions for an account
|
// GetByAccountID retrieves recovery sessions for an account
|
||||||
func (r *RecoverySessionPostgresRepo) GetByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.RecoverySession, error) {
|
func (r *RecoverySessionPostgresRepo) GetByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.RecoverySession, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, account_id, recovery_type, old_share_type,
|
SELECT id, account_id, recovery_type, old_share_type,
|
||||||
new_keygen_session_id, status, requested_at, completed_at
|
new_keygen_session_id, status, requested_at, completed_at
|
||||||
FROM account_recovery_sessions
|
FROM account_recovery_sessions
|
||||||
WHERE account_id = $1
|
WHERE account_id = $1
|
||||||
ORDER BY requested_at DESC
|
ORDER BY requested_at DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
return r.querySessions(ctx, query, accountID.UUID())
|
return r.querySessions(ctx, query, accountID.UUID())
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveByAccountID retrieves active recovery sessions for an account
|
// GetActiveByAccountID retrieves active recovery sessions for an account
|
||||||
func (r *RecoverySessionPostgresRepo) GetActiveByAccountID(ctx context.Context, accountID value_objects.AccountID) (*entities.RecoverySession, error) {
|
func (r *RecoverySessionPostgresRepo) GetActiveByAccountID(ctx context.Context, accountID value_objects.AccountID) (*entities.RecoverySession, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, account_id, recovery_type, old_share_type,
|
SELECT id, account_id, recovery_type, old_share_type,
|
||||||
new_keygen_session_id, status, requested_at, completed_at
|
new_keygen_session_id, status, requested_at, completed_at
|
||||||
FROM account_recovery_sessions
|
FROM account_recovery_sessions
|
||||||
WHERE account_id = $1 AND status IN ('requested', 'in_progress')
|
WHERE account_id = $1 AND status IN ('requested', 'in_progress')
|
||||||
ORDER BY requested_at DESC
|
ORDER BY requested_at DESC
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
`
|
`
|
||||||
|
|
||||||
return r.scanSession(r.db.QueryRowContext(ctx, query, accountID.UUID()))
|
return r.scanSession(r.db.QueryRowContext(ctx, query, accountID.UUID()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update updates a recovery session
|
// Update updates a recovery session
|
||||||
func (r *RecoverySessionPostgresRepo) Update(ctx context.Context, session *entities.RecoverySession) error {
|
func (r *RecoverySessionPostgresRepo) Update(ctx context.Context, session *entities.RecoverySession) error {
|
||||||
query := `
|
query := `
|
||||||
UPDATE account_recovery_sessions
|
UPDATE account_recovery_sessions
|
||||||
SET recovery_type = $2, old_share_type = $3, new_keygen_session_id = $4,
|
SET recovery_type = $2, old_share_type = $3, new_keygen_session_id = $4,
|
||||||
status = $5, completed_at = $6
|
status = $5, completed_at = $6
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
var oldShareType *string
|
var oldShareType *string
|
||||||
if session.OldShareType != nil {
|
if session.OldShareType != nil {
|
||||||
s := session.OldShareType.String()
|
s := session.OldShareType.String()
|
||||||
oldShareType = &s
|
oldShareType = &s
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := r.db.ExecContext(ctx, query,
|
result, err := r.db.ExecContext(ctx, query,
|
||||||
session.ID,
|
session.ID,
|
||||||
session.RecoveryType.String(),
|
session.RecoveryType.String(),
|
||||||
oldShareType,
|
oldShareType,
|
||||||
session.NewKeygenSessionID,
|
session.NewKeygenSessionID,
|
||||||
session.Status.String(),
|
session.Status.String(),
|
||||||
session.CompletedAt,
|
session.CompletedAt,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsAffected, err := result.RowsAffected()
|
rowsAffected, err := result.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
if rowsAffected == 0 {
|
||||||
return entities.ErrRecoveryNotFound
|
return entities.ErrRecoveryNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete deletes a recovery session
|
// Delete deletes a recovery session
|
||||||
func (r *RecoverySessionPostgresRepo) Delete(ctx context.Context, id string) error {
|
func (r *RecoverySessionPostgresRepo) Delete(ctx context.Context, id string) error {
|
||||||
sessionID, err := uuid.Parse(id)
|
sessionID, err := uuid.Parse(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return entities.ErrRecoveryNotFound
|
return entities.ErrRecoveryNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
query := `DELETE FROM account_recovery_sessions WHERE id = $1`
|
query := `DELETE FROM account_recovery_sessions WHERE id = $1`
|
||||||
|
|
||||||
result, err := r.db.ExecContext(ctx, query, sessionID)
|
result, err := r.db.ExecContext(ctx, query, sessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsAffected, err := result.RowsAffected()
|
rowsAffected, err := result.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
if rowsAffected == 0 {
|
||||||
return entities.ErrRecoveryNotFound
|
return entities.ErrRecoveryNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// scanSession scans a single recovery session row
|
// scanSession scans a single recovery session row
|
||||||
func (r *RecoverySessionPostgresRepo) scanSession(row *sql.Row) (*entities.RecoverySession, error) {
|
func (r *RecoverySessionPostgresRepo) scanSession(row *sql.Row) (*entities.RecoverySession, error) {
|
||||||
var (
|
var (
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
accountID uuid.UUID
|
accountID uuid.UUID
|
||||||
recoveryType string
|
recoveryType string
|
||||||
oldShareType sql.NullString
|
oldShareType sql.NullString
|
||||||
newKeygenSessionID sql.NullString
|
newKeygenSessionID sql.NullString
|
||||||
status string
|
status string
|
||||||
session entities.RecoverySession
|
session entities.RecoverySession
|
||||||
)
|
)
|
||||||
|
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&id,
|
&id,
|
||||||
&accountID,
|
&accountID,
|
||||||
&recoveryType,
|
&recoveryType,
|
||||||
&oldShareType,
|
&oldShareType,
|
||||||
&newKeygenSessionID,
|
&newKeygenSessionID,
|
||||||
&status,
|
&status,
|
||||||
&session.RequestedAt,
|
&session.RequestedAt,
|
||||||
&session.CompletedAt,
|
&session.CompletedAt,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, entities.ErrRecoveryNotFound
|
return nil, entities.ErrRecoveryNotFound
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
session.ID = id
|
session.ID = id
|
||||||
session.AccountID = value_objects.AccountIDFromUUID(accountID)
|
session.AccountID = value_objects.AccountIDFromUUID(accountID)
|
||||||
session.RecoveryType = value_objects.RecoveryType(recoveryType)
|
session.RecoveryType = value_objects.RecoveryType(recoveryType)
|
||||||
session.Status = value_objects.RecoveryStatus(status)
|
session.Status = value_objects.RecoveryStatus(status)
|
||||||
|
|
||||||
if oldShareType.Valid {
|
if oldShareType.Valid {
|
||||||
st := value_objects.ShareType(oldShareType.String)
|
st := value_objects.ShareType(oldShareType.String)
|
||||||
session.OldShareType = &st
|
session.OldShareType = &st
|
||||||
}
|
}
|
||||||
|
|
||||||
if newKeygenSessionID.Valid {
|
if newKeygenSessionID.Valid {
|
||||||
if keygenID, err := uuid.Parse(newKeygenSessionID.String); err == nil {
|
if keygenID, err := uuid.Parse(newKeygenSessionID.String); err == nil {
|
||||||
session.NewKeygenSessionID = &keygenID
|
session.NewKeygenSessionID = &keygenID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &session, nil
|
return &session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// querySessions queries multiple recovery sessions
|
// querySessions queries multiple recovery sessions
|
||||||
func (r *RecoverySessionPostgresRepo) querySessions(ctx context.Context, query string, args ...interface{}) ([]*entities.RecoverySession, error) {
|
func (r *RecoverySessionPostgresRepo) querySessions(ctx context.Context, query string, args ...interface{}) ([]*entities.RecoverySession, error) {
|
||||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
var sessions []*entities.RecoverySession
|
var sessions []*entities.RecoverySession
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
accountID uuid.UUID
|
accountID uuid.UUID
|
||||||
recoveryType string
|
recoveryType string
|
||||||
oldShareType sql.NullString
|
oldShareType sql.NullString
|
||||||
newKeygenSessionID sql.NullString
|
newKeygenSessionID sql.NullString
|
||||||
status string
|
status string
|
||||||
session entities.RecoverySession
|
session entities.RecoverySession
|
||||||
)
|
)
|
||||||
|
|
||||||
err := rows.Scan(
|
err := rows.Scan(
|
||||||
&id,
|
&id,
|
||||||
&accountID,
|
&accountID,
|
||||||
&recoveryType,
|
&recoveryType,
|
||||||
&oldShareType,
|
&oldShareType,
|
||||||
&newKeygenSessionID,
|
&newKeygenSessionID,
|
||||||
&status,
|
&status,
|
||||||
&session.RequestedAt,
|
&session.RequestedAt,
|
||||||
&session.CompletedAt,
|
&session.CompletedAt,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
session.ID = id
|
session.ID = id
|
||||||
session.AccountID = value_objects.AccountIDFromUUID(accountID)
|
session.AccountID = value_objects.AccountIDFromUUID(accountID)
|
||||||
session.RecoveryType = value_objects.RecoveryType(recoveryType)
|
session.RecoveryType = value_objects.RecoveryType(recoveryType)
|
||||||
session.Status = value_objects.RecoveryStatus(status)
|
session.Status = value_objects.RecoveryStatus(status)
|
||||||
|
|
||||||
if oldShareType.Valid {
|
if oldShareType.Valid {
|
||||||
st := value_objects.ShareType(oldShareType.String)
|
st := value_objects.ShareType(oldShareType.String)
|
||||||
session.OldShareType = &st
|
session.OldShareType = &st
|
||||||
}
|
}
|
||||||
|
|
||||||
if newKeygenSessionID.Valid {
|
if newKeygenSessionID.Valid {
|
||||||
if keygenID, err := uuid.Parse(newKeygenSessionID.String); err == nil {
|
if keygenID, err := uuid.Parse(newKeygenSessionID.String); err == nil {
|
||||||
session.NewKeygenSessionID = &keygenID
|
session.NewKeygenSessionID = &keygenID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sessions = append(sessions, &session)
|
sessions = append(sessions, &session)
|
||||||
}
|
}
|
||||||
|
|
||||||
return sessions, rows.Err()
|
return sessions, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,284 +1,284 @@
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccountSharePostgresRepo implements AccountShareRepository using PostgreSQL
|
// AccountSharePostgresRepo implements AccountShareRepository using PostgreSQL
|
||||||
type AccountSharePostgresRepo struct {
|
type AccountSharePostgresRepo struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountSharePostgresRepo creates a new AccountSharePostgresRepo
|
// NewAccountSharePostgresRepo creates a new AccountSharePostgresRepo
|
||||||
func NewAccountSharePostgresRepo(db *sql.DB) repositories.AccountShareRepository {
|
func NewAccountSharePostgresRepo(db *sql.DB) repositories.AccountShareRepository {
|
||||||
return &AccountSharePostgresRepo{db: db}
|
return &AccountSharePostgresRepo{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create creates a new account share
|
// Create creates a new account share
|
||||||
func (r *AccountSharePostgresRepo) Create(ctx context.Context, share *entities.AccountShare) error {
|
func (r *AccountSharePostgresRepo) Create(ctx context.Context, share *entities.AccountShare) error {
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO account_shares (id, account_id, share_type, party_id, party_index,
|
INSERT INTO account_shares (id, account_id, share_type, party_id, party_index,
|
||||||
device_type, device_id, created_at, last_used_at, is_active)
|
device_type, device_id, created_at, last_used_at, is_active)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||||
`
|
`
|
||||||
|
|
||||||
_, err := r.db.ExecContext(ctx, query,
|
_, err := r.db.ExecContext(ctx, query,
|
||||||
share.ID,
|
share.ID,
|
||||||
share.AccountID.UUID(),
|
share.AccountID.UUID(),
|
||||||
share.ShareType.String(),
|
share.ShareType.String(),
|
||||||
share.PartyID,
|
share.PartyID,
|
||||||
share.PartyIndex,
|
share.PartyIndex,
|
||||||
share.DeviceType,
|
share.DeviceType,
|
||||||
share.DeviceID,
|
share.DeviceID,
|
||||||
share.CreatedAt,
|
share.CreatedAt,
|
||||||
share.LastUsedAt,
|
share.LastUsedAt,
|
||||||
share.IsActive,
|
share.IsActive,
|
||||||
)
|
)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID retrieves a share by ID
|
// GetByID retrieves a share by ID
|
||||||
func (r *AccountSharePostgresRepo) GetByID(ctx context.Context, id string) (*entities.AccountShare, error) {
|
func (r *AccountSharePostgresRepo) GetByID(ctx context.Context, id string) (*entities.AccountShare, error) {
|
||||||
shareID, err := uuid.Parse(id)
|
shareID, err := uuid.Parse(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, entities.ErrShareNotFound
|
return nil, entities.ErrShareNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
SELECT id, account_id, share_type, party_id, party_index,
|
SELECT id, account_id, share_type, party_id, party_index,
|
||||||
device_type, device_id, created_at, last_used_at, is_active
|
device_type, device_id, created_at, last_used_at, is_active
|
||||||
FROM account_shares
|
FROM account_shares
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
return r.scanShare(r.db.QueryRowContext(ctx, query, shareID))
|
return r.scanShare(r.db.QueryRowContext(ctx, query, shareID))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByAccountID retrieves all shares for an account
|
// GetByAccountID retrieves all shares for an account
|
||||||
func (r *AccountSharePostgresRepo) GetByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.AccountShare, error) {
|
func (r *AccountSharePostgresRepo) GetByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.AccountShare, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, account_id, share_type, party_id, party_index,
|
SELECT id, account_id, share_type, party_id, party_index,
|
||||||
device_type, device_id, created_at, last_used_at, is_active
|
device_type, device_id, created_at, last_used_at, is_active
|
||||||
FROM account_shares
|
FROM account_shares
|
||||||
WHERE account_id = $1
|
WHERE account_id = $1
|
||||||
ORDER BY party_index
|
ORDER BY party_index
|
||||||
`
|
`
|
||||||
|
|
||||||
return r.queryShares(ctx, query, accountID.UUID())
|
return r.queryShares(ctx, query, accountID.UUID())
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveByAccountID retrieves active shares for an account
|
// GetActiveByAccountID retrieves active shares for an account
|
||||||
func (r *AccountSharePostgresRepo) GetActiveByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.AccountShare, error) {
|
func (r *AccountSharePostgresRepo) GetActiveByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.AccountShare, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, account_id, share_type, party_id, party_index,
|
SELECT id, account_id, share_type, party_id, party_index,
|
||||||
device_type, device_id, created_at, last_used_at, is_active
|
device_type, device_id, created_at, last_used_at, is_active
|
||||||
FROM account_shares
|
FROM account_shares
|
||||||
WHERE account_id = $1 AND is_active = TRUE
|
WHERE account_id = $1 AND is_active = TRUE
|
||||||
ORDER BY party_index
|
ORDER BY party_index
|
||||||
`
|
`
|
||||||
|
|
||||||
return r.queryShares(ctx, query, accountID.UUID())
|
return r.queryShares(ctx, query, accountID.UUID())
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByPartyID retrieves shares by party ID
|
// GetByPartyID retrieves shares by party ID
|
||||||
func (r *AccountSharePostgresRepo) GetByPartyID(ctx context.Context, partyID string) ([]*entities.AccountShare, error) {
|
func (r *AccountSharePostgresRepo) GetByPartyID(ctx context.Context, partyID string) ([]*entities.AccountShare, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, account_id, share_type, party_id, party_index,
|
SELECT id, account_id, share_type, party_id, party_index,
|
||||||
device_type, device_id, created_at, last_used_at, is_active
|
device_type, device_id, created_at, last_used_at, is_active
|
||||||
FROM account_shares
|
FROM account_shares
|
||||||
WHERE party_id = $1
|
WHERE party_id = $1
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
return r.queryShares(ctx, query, partyID)
|
return r.queryShares(ctx, query, partyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update updates a share
|
// Update updates a share
|
||||||
func (r *AccountSharePostgresRepo) Update(ctx context.Context, share *entities.AccountShare) error {
|
func (r *AccountSharePostgresRepo) Update(ctx context.Context, share *entities.AccountShare) error {
|
||||||
query := `
|
query := `
|
||||||
UPDATE account_shares
|
UPDATE account_shares
|
||||||
SET share_type = $2, party_id = $3, party_index = $4,
|
SET share_type = $2, party_id = $3, party_index = $4,
|
||||||
device_type = $5, device_id = $6, last_used_at = $7, is_active = $8
|
device_type = $5, device_id = $6, last_used_at = $7, is_active = $8
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
result, err := r.db.ExecContext(ctx, query,
|
result, err := r.db.ExecContext(ctx, query,
|
||||||
share.ID,
|
share.ID,
|
||||||
share.ShareType.String(),
|
share.ShareType.String(),
|
||||||
share.PartyID,
|
share.PartyID,
|
||||||
share.PartyIndex,
|
share.PartyIndex,
|
||||||
share.DeviceType,
|
share.DeviceType,
|
||||||
share.DeviceID,
|
share.DeviceID,
|
||||||
share.LastUsedAt,
|
share.LastUsedAt,
|
||||||
share.IsActive,
|
share.IsActive,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsAffected, err := result.RowsAffected()
|
rowsAffected, err := result.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
if rowsAffected == 0 {
|
||||||
return entities.ErrShareNotFound
|
return entities.ErrShareNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete deletes a share
|
// Delete deletes a share
|
||||||
func (r *AccountSharePostgresRepo) Delete(ctx context.Context, id string) error {
|
func (r *AccountSharePostgresRepo) Delete(ctx context.Context, id string) error {
|
||||||
shareID, err := uuid.Parse(id)
|
shareID, err := uuid.Parse(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return entities.ErrShareNotFound
|
return entities.ErrShareNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
query := `DELETE FROM account_shares WHERE id = $1`
|
query := `DELETE FROM account_shares WHERE id = $1`
|
||||||
|
|
||||||
result, err := r.db.ExecContext(ctx, query, shareID)
|
result, err := r.db.ExecContext(ctx, query, shareID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsAffected, err := result.RowsAffected()
|
rowsAffected, err := result.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
if rowsAffected == 0 {
|
||||||
return entities.ErrShareNotFound
|
return entities.ErrShareNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeactivateByAccountID deactivates all shares for an account
|
// DeactivateByAccountID deactivates all shares for an account
|
||||||
func (r *AccountSharePostgresRepo) DeactivateByAccountID(ctx context.Context, accountID value_objects.AccountID) error {
|
func (r *AccountSharePostgresRepo) DeactivateByAccountID(ctx context.Context, accountID value_objects.AccountID) error {
|
||||||
query := `UPDATE account_shares SET is_active = FALSE WHERE account_id = $1`
|
query := `UPDATE account_shares SET is_active = FALSE WHERE account_id = $1`
|
||||||
|
|
||||||
_, err := r.db.ExecContext(ctx, query, accountID.UUID())
|
_, err := r.db.ExecContext(ctx, query, accountID.UUID())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeactivateByShareType deactivates shares of a specific type for an account
|
// DeactivateByShareType deactivates shares of a specific type for an account
|
||||||
func (r *AccountSharePostgresRepo) DeactivateByShareType(ctx context.Context, accountID value_objects.AccountID, shareType value_objects.ShareType) error {
|
func (r *AccountSharePostgresRepo) DeactivateByShareType(ctx context.Context, accountID value_objects.AccountID, shareType value_objects.ShareType) error {
|
||||||
query := `UPDATE account_shares SET is_active = FALSE WHERE account_id = $1 AND share_type = $2`
|
query := `UPDATE account_shares SET is_active = FALSE WHERE account_id = $1 AND share_type = $2`
|
||||||
|
|
||||||
_, err := r.db.ExecContext(ctx, query, accountID.UUID(), shareType.String())
|
_, err := r.db.ExecContext(ctx, query, accountID.UUID(), shareType.String())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// scanShare scans a single share row
|
// scanShare scans a single share row
|
||||||
func (r *AccountSharePostgresRepo) scanShare(row *sql.Row) (*entities.AccountShare, error) {
|
func (r *AccountSharePostgresRepo) scanShare(row *sql.Row) (*entities.AccountShare, error) {
|
||||||
var (
|
var (
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
accountID uuid.UUID
|
accountID uuid.UUID
|
||||||
shareType string
|
shareType string
|
||||||
partyID string
|
partyID string
|
||||||
partyIndex int
|
partyIndex int
|
||||||
deviceType sql.NullString
|
deviceType sql.NullString
|
||||||
deviceID sql.NullString
|
deviceID sql.NullString
|
||||||
share entities.AccountShare
|
share entities.AccountShare
|
||||||
)
|
)
|
||||||
|
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&id,
|
&id,
|
||||||
&accountID,
|
&accountID,
|
||||||
&shareType,
|
&shareType,
|
||||||
&partyID,
|
&partyID,
|
||||||
&partyIndex,
|
&partyIndex,
|
||||||
&deviceType,
|
&deviceType,
|
||||||
&deviceID,
|
&deviceID,
|
||||||
&share.CreatedAt,
|
&share.CreatedAt,
|
||||||
&share.LastUsedAt,
|
&share.LastUsedAt,
|
||||||
&share.IsActive,
|
&share.IsActive,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, entities.ErrShareNotFound
|
return nil, entities.ErrShareNotFound
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
share.ID = id
|
share.ID = id
|
||||||
share.AccountID = value_objects.AccountIDFromUUID(accountID)
|
share.AccountID = value_objects.AccountIDFromUUID(accountID)
|
||||||
share.ShareType = value_objects.ShareType(shareType)
|
share.ShareType = value_objects.ShareType(shareType)
|
||||||
share.PartyID = partyID
|
share.PartyID = partyID
|
||||||
share.PartyIndex = partyIndex
|
share.PartyIndex = partyIndex
|
||||||
if deviceType.Valid {
|
if deviceType.Valid {
|
||||||
share.DeviceType = &deviceType.String
|
share.DeviceType = &deviceType.String
|
||||||
}
|
}
|
||||||
if deviceID.Valid {
|
if deviceID.Valid {
|
||||||
share.DeviceID = &deviceID.String
|
share.DeviceID = &deviceID.String
|
||||||
}
|
}
|
||||||
|
|
||||||
return &share, nil
|
return &share, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// queryShares queries multiple shares
|
// queryShares queries multiple shares
|
||||||
func (r *AccountSharePostgresRepo) queryShares(ctx context.Context, query string, args ...interface{}) ([]*entities.AccountShare, error) {
|
func (r *AccountSharePostgresRepo) queryShares(ctx context.Context, query string, args ...interface{}) ([]*entities.AccountShare, error) {
|
||||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
var shares []*entities.AccountShare
|
var shares []*entities.AccountShare
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
accountID uuid.UUID
|
accountID uuid.UUID
|
||||||
shareType string
|
shareType string
|
||||||
partyID string
|
partyID string
|
||||||
partyIndex int
|
partyIndex int
|
||||||
deviceType sql.NullString
|
deviceType sql.NullString
|
||||||
deviceID sql.NullString
|
deviceID sql.NullString
|
||||||
share entities.AccountShare
|
share entities.AccountShare
|
||||||
)
|
)
|
||||||
|
|
||||||
err := rows.Scan(
|
err := rows.Scan(
|
||||||
&id,
|
&id,
|
||||||
&accountID,
|
&accountID,
|
||||||
&shareType,
|
&shareType,
|
||||||
&partyID,
|
&partyID,
|
||||||
&partyIndex,
|
&partyIndex,
|
||||||
&deviceType,
|
&deviceType,
|
||||||
&deviceID,
|
&deviceID,
|
||||||
&share.CreatedAt,
|
&share.CreatedAt,
|
||||||
&share.LastUsedAt,
|
&share.LastUsedAt,
|
||||||
&share.IsActive,
|
&share.IsActive,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
share.ID = id
|
share.ID = id
|
||||||
share.AccountID = value_objects.AccountIDFromUUID(accountID)
|
share.AccountID = value_objects.AccountIDFromUUID(accountID)
|
||||||
share.ShareType = value_objects.ShareType(shareType)
|
share.ShareType = value_objects.ShareType(shareType)
|
||||||
share.PartyID = partyID
|
share.PartyID = partyID
|
||||||
share.PartyIndex = partyIndex
|
share.PartyIndex = partyIndex
|
||||||
if deviceType.Valid {
|
if deviceType.Valid {
|
||||||
share.DeviceType = &deviceType.String
|
share.DeviceType = &deviceType.String
|
||||||
}
|
}
|
||||||
if deviceID.Valid {
|
if deviceID.Valid {
|
||||||
share.DeviceID = &deviceID.String
|
share.DeviceID = &deviceID.String
|
||||||
}
|
}
|
||||||
|
|
||||||
shares = append(shares, &share)
|
shares = append(shares, &share)
|
||||||
}
|
}
|
||||||
|
|
||||||
return shares, rows.Err()
|
return shares, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,80 +1,80 @@
|
||||||
package rabbitmq
|
package rabbitmq
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
amqp "github.com/rabbitmq/amqp091-go"
|
amqp "github.com/rabbitmq/amqp091-go"
|
||||||
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
exchangeName = "account.events"
|
exchangeName = "account.events"
|
||||||
exchangeType = "topic"
|
exchangeType = "topic"
|
||||||
)
|
)
|
||||||
|
|
||||||
// EventPublisherAdapter implements EventPublisher using RabbitMQ
|
// EventPublisherAdapter implements EventPublisher using RabbitMQ
|
||||||
type EventPublisherAdapter struct {
|
type EventPublisherAdapter struct {
|
||||||
conn *amqp.Connection
|
conn *amqp.Connection
|
||||||
channel *amqp.Channel
|
channel *amqp.Channel
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewEventPublisherAdapter creates a new EventPublisherAdapter
|
// NewEventPublisherAdapter creates a new EventPublisherAdapter
|
||||||
func NewEventPublisherAdapter(conn *amqp.Connection) (*EventPublisherAdapter, error) {
|
func NewEventPublisherAdapter(conn *amqp.Connection) (*EventPublisherAdapter, error) {
|
||||||
channel, err := conn.Channel()
|
channel, err := conn.Channel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Declare exchange
|
// Declare exchange
|
||||||
err = channel.ExchangeDeclare(
|
err = channel.ExchangeDeclare(
|
||||||
exchangeName,
|
exchangeName,
|
||||||
exchangeType,
|
exchangeType,
|
||||||
true, // durable
|
true, // durable
|
||||||
false, // auto-deleted
|
false, // auto-deleted
|
||||||
false, // internal
|
false, // internal
|
||||||
false, // no-wait
|
false, // no-wait
|
||||||
nil, // arguments
|
nil, // arguments
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
channel.Close()
|
channel.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &EventPublisherAdapter{
|
return &EventPublisherAdapter{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
channel: channel,
|
channel: channel,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish publishes an account event
|
// Publish publishes an account event
|
||||||
func (p *EventPublisherAdapter) Publish(ctx context.Context, event ports.AccountEvent) error {
|
func (p *EventPublisherAdapter) Publish(ctx context.Context, event ports.AccountEvent) error {
|
||||||
body, err := json.Marshal(event)
|
body, err := json.Marshal(event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
routingKey := string(event.Type)
|
routingKey := string(event.Type)
|
||||||
|
|
||||||
return p.channel.PublishWithContext(ctx,
|
return p.channel.PublishWithContext(ctx,
|
||||||
exchangeName,
|
exchangeName,
|
||||||
routingKey,
|
routingKey,
|
||||||
false, // mandatory
|
false, // mandatory
|
||||||
false, // immediate
|
false, // immediate
|
||||||
amqp.Publishing{
|
amqp.Publishing{
|
||||||
ContentType: "application/json",
|
ContentType: "application/json",
|
||||||
DeliveryMode: amqp.Persistent,
|
DeliveryMode: amqp.Persistent,
|
||||||
Timestamp: time.Now().UTC(),
|
Timestamp: time.Now().UTC(),
|
||||||
Body: body,
|
Body: body,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the publisher
|
// Close closes the publisher
|
||||||
func (p *EventPublisherAdapter) Close() error {
|
func (p *EventPublisherAdapter) Close() error {
|
||||||
if p.channel != nil {
|
if p.channel != nil {
|
||||||
return p.channel.Close()
|
return p.channel.Close()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,181 +1,181 @@
|
||||||
package redis
|
package redis
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CacheAdapter implements CacheService using Redis
|
// CacheAdapter implements CacheService using Redis
|
||||||
type CacheAdapter struct {
|
type CacheAdapter struct {
|
||||||
client *redis.Client
|
client *redis.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCacheAdapter creates a new CacheAdapter
|
// NewCacheAdapter creates a new CacheAdapter
|
||||||
func NewCacheAdapter(client *redis.Client) ports.CacheService {
|
func NewCacheAdapter(client *redis.Client) ports.CacheService {
|
||||||
return &CacheAdapter{client: client}
|
return &CacheAdapter{client: client}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set sets a value in the cache
|
// Set sets a value in the cache
|
||||||
func (c *CacheAdapter) Set(ctx context.Context, key string, value interface{}, ttlSeconds int) error {
|
func (c *CacheAdapter) Set(ctx context.Context, key string, value interface{}, ttlSeconds int) error {
|
||||||
data, err := json.Marshal(value)
|
data, err := json.Marshal(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.client.Set(ctx, key, data, time.Duration(ttlSeconds)*time.Second).Err()
|
return c.client.Set(ctx, key, data, time.Duration(ttlSeconds)*time.Second).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get gets a value from the cache
|
// Get gets a value from the cache
|
||||||
func (c *CacheAdapter) Get(ctx context.Context, key string) (interface{}, error) {
|
func (c *CacheAdapter) Get(ctx context.Context, key string) (interface{}, error) {
|
||||||
data, err := c.client.Get(ctx, key).Bytes()
|
data, err := c.client.Get(ctx, key).Bytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == redis.Nil {
|
if err == redis.Nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var value interface{}
|
var value interface{}
|
||||||
if err := json.Unmarshal(data, &value); err != nil {
|
if err := json.Unmarshal(data, &value); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return value, nil
|
return value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete deletes a value from the cache
|
// Delete deletes a value from the cache
|
||||||
func (c *CacheAdapter) Delete(ctx context.Context, key string) error {
|
func (c *CacheAdapter) Delete(ctx context.Context, key string) error {
|
||||||
return c.client.Del(ctx, key).Err()
|
return c.client.Del(ctx, key).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exists checks if a key exists in the cache
|
// Exists checks if a key exists in the cache
|
||||||
func (c *CacheAdapter) Exists(ctx context.Context, key string) (bool, error) {
|
func (c *CacheAdapter) Exists(ctx context.Context, key string) (bool, error) {
|
||||||
result, err := c.client.Exists(ctx, key).Result()
|
result, err := c.client.Exists(ctx, key).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
return result > 0, nil
|
return result > 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountCacheAdapter provides account-specific caching
|
// AccountCacheAdapter provides account-specific caching
|
||||||
type AccountCacheAdapter struct {
|
type AccountCacheAdapter struct {
|
||||||
client *redis.Client
|
client *redis.Client
|
||||||
keyPrefix string
|
keyPrefix string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountCacheAdapter creates a new AccountCacheAdapter
|
// NewAccountCacheAdapter creates a new AccountCacheAdapter
|
||||||
func NewAccountCacheAdapter(client *redis.Client) *AccountCacheAdapter {
|
func NewAccountCacheAdapter(client *redis.Client) *AccountCacheAdapter {
|
||||||
return &AccountCacheAdapter{
|
return &AccountCacheAdapter{
|
||||||
client: client,
|
client: client,
|
||||||
keyPrefix: "account:",
|
keyPrefix: "account:",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CacheAccount caches an account
|
// CacheAccount caches an account
|
||||||
func (c *AccountCacheAdapter) CacheAccount(ctx context.Context, accountID string, data interface{}, ttl time.Duration) error {
|
func (c *AccountCacheAdapter) CacheAccount(ctx context.Context, accountID string, data interface{}, ttl time.Duration) error {
|
||||||
key := c.keyPrefix + accountID
|
key := c.keyPrefix + accountID
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.client.Set(ctx, key, jsonData, ttl).Err()
|
return c.client.Set(ctx, key, jsonData, ttl).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCachedAccount gets a cached account
|
// GetCachedAccount gets a cached account
|
||||||
func (c *AccountCacheAdapter) GetCachedAccount(ctx context.Context, accountID string) (map[string]interface{}, error) {
|
func (c *AccountCacheAdapter) GetCachedAccount(ctx context.Context, accountID string) (map[string]interface{}, error) {
|
||||||
key := c.keyPrefix + accountID
|
key := c.keyPrefix + accountID
|
||||||
data, err := c.client.Get(ctx, key).Bytes()
|
data, err := c.client.Get(ctx, key).Bytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == redis.Nil {
|
if err == redis.Nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var result map[string]interface{}
|
var result map[string]interface{}
|
||||||
if err := json.Unmarshal(data, &result); err != nil {
|
if err := json.Unmarshal(data, &result); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// InvalidateAccount invalidates cached account data
|
// InvalidateAccount invalidates cached account data
|
||||||
func (c *AccountCacheAdapter) InvalidateAccount(ctx context.Context, accountID string) error {
|
func (c *AccountCacheAdapter) InvalidateAccount(ctx context.Context, accountID string) error {
|
||||||
key := c.keyPrefix + accountID
|
key := c.keyPrefix + accountID
|
||||||
return c.client.Del(ctx, key).Err()
|
return c.client.Del(ctx, key).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CacheLoginChallenge caches a login challenge
|
// CacheLoginChallenge caches a login challenge
|
||||||
func (c *AccountCacheAdapter) CacheLoginChallenge(ctx context.Context, challengeID string, data map[string]interface{}) error {
|
func (c *AccountCacheAdapter) CacheLoginChallenge(ctx context.Context, challengeID string, data map[string]interface{}) error {
|
||||||
key := "login_challenge:" + challengeID
|
key := "login_challenge:" + challengeID
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.client.Set(ctx, key, jsonData, 5*time.Minute).Err()
|
return c.client.Set(ctx, key, jsonData, 5*time.Minute).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLoginChallenge gets a login challenge
|
// GetLoginChallenge gets a login challenge
|
||||||
func (c *AccountCacheAdapter) GetLoginChallenge(ctx context.Context, challengeID string) (map[string]interface{}, error) {
|
func (c *AccountCacheAdapter) GetLoginChallenge(ctx context.Context, challengeID string) (map[string]interface{}, error) {
|
||||||
key := "login_challenge:" + challengeID
|
key := "login_challenge:" + challengeID
|
||||||
data, err := c.client.Get(ctx, key).Bytes()
|
data, err := c.client.Get(ctx, key).Bytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == redis.Nil {
|
if err == redis.Nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var result map[string]interface{}
|
var result map[string]interface{}
|
||||||
if err := json.Unmarshal(data, &result); err != nil {
|
if err := json.Unmarshal(data, &result); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteLoginChallenge deletes a login challenge after use
|
// DeleteLoginChallenge deletes a login challenge after use
|
||||||
func (c *AccountCacheAdapter) DeleteLoginChallenge(ctx context.Context, challengeID string) error {
|
func (c *AccountCacheAdapter) DeleteLoginChallenge(ctx context.Context, challengeID string) error {
|
||||||
key := "login_challenge:" + challengeID
|
key := "login_challenge:" + challengeID
|
||||||
return c.client.Del(ctx, key).Err()
|
return c.client.Del(ctx, key).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IncrementLoginAttempts increments failed login attempts
|
// IncrementLoginAttempts increments failed login attempts
|
||||||
func (c *AccountCacheAdapter) IncrementLoginAttempts(ctx context.Context, username string) (int64, error) {
|
func (c *AccountCacheAdapter) IncrementLoginAttempts(ctx context.Context, username string) (int64, error) {
|
||||||
key := "login_attempts:" + username
|
key := "login_attempts:" + username
|
||||||
count, err := c.client.Incr(ctx, key).Result()
|
count, err := c.client.Incr(ctx, key).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set expiry on first attempt
|
// Set expiry on first attempt
|
||||||
if count == 1 {
|
if count == 1 {
|
||||||
c.client.Expire(ctx, key, 15*time.Minute)
|
c.client.Expire(ctx, key, 15*time.Minute)
|
||||||
}
|
}
|
||||||
|
|
||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLoginAttempts gets the current login attempt count
|
// GetLoginAttempts gets the current login attempt count
|
||||||
func (c *AccountCacheAdapter) GetLoginAttempts(ctx context.Context, username string) (int64, error) {
|
func (c *AccountCacheAdapter) GetLoginAttempts(ctx context.Context, username string) (int64, error) {
|
||||||
key := "login_attempts:" + username
|
key := "login_attempts:" + username
|
||||||
count, err := c.client.Get(ctx, key).Int64()
|
count, err := c.client.Get(ctx, key).Int64()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == redis.Nil {
|
if err == redis.Nil {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResetLoginAttempts resets login attempts after successful login
|
// ResetLoginAttempts resets login attempts after successful login
|
||||||
func (c *AccountCacheAdapter) ResetLoginAttempts(ctx context.Context, username string) error {
|
func (c *AccountCacheAdapter) ResetLoginAttempts(ctx context.Context, username string) error {
|
||||||
key := "login_attempts:" + username
|
key := "login_attempts:" + username
|
||||||
return c.client.Del(ctx, key).Err()
|
return c.client.Del(ctx, key).Err()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,140 +1,140 @@
|
||||||
package ports
|
package ports
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateAccountInput represents input for creating an account
|
// CreateAccountInput represents input for creating an account
|
||||||
type CreateAccountInput struct {
|
type CreateAccountInput struct {
|
||||||
Username string
|
Username string
|
||||||
Email string
|
Email string
|
||||||
Phone *string
|
Phone *string
|
||||||
PublicKey []byte
|
PublicKey []byte
|
||||||
KeygenSessionID uuid.UUID
|
KeygenSessionID uuid.UUID
|
||||||
ThresholdN int
|
ThresholdN int
|
||||||
ThresholdT int
|
ThresholdT int
|
||||||
Shares []ShareInput
|
Shares []ShareInput
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShareInput represents input for a key share
|
// ShareInput represents input for a key share
|
||||||
type ShareInput struct {
|
type ShareInput struct {
|
||||||
ShareType value_objects.ShareType
|
ShareType value_objects.ShareType
|
||||||
PartyID string
|
PartyID string
|
||||||
PartyIndex int
|
PartyIndex int
|
||||||
DeviceType *string
|
DeviceType *string
|
||||||
DeviceID *string
|
DeviceID *string
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccountOutput represents output from creating an account
|
// CreateAccountOutput represents output from creating an account
|
||||||
type CreateAccountOutput struct {
|
type CreateAccountOutput struct {
|
||||||
Account *entities.Account
|
Account *entities.Account
|
||||||
Shares []*entities.AccountShare
|
Shares []*entities.AccountShare
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccountPort defines the input port for creating accounts
|
// CreateAccountPort defines the input port for creating accounts
|
||||||
type CreateAccountPort interface {
|
type CreateAccountPort interface {
|
||||||
Execute(ctx context.Context, input CreateAccountInput) (*CreateAccountOutput, error)
|
Execute(ctx context.Context, input CreateAccountInput) (*CreateAccountOutput, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountInput represents input for getting an account
|
// GetAccountInput represents input for getting an account
|
||||||
type GetAccountInput struct {
|
type GetAccountInput struct {
|
||||||
AccountID *value_objects.AccountID
|
AccountID *value_objects.AccountID
|
||||||
Username *string
|
Username *string
|
||||||
Email *string
|
Email *string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountOutput represents output from getting an account
|
// GetAccountOutput represents output from getting an account
|
||||||
type GetAccountOutput struct {
|
type GetAccountOutput struct {
|
||||||
Account *entities.Account
|
Account *entities.Account
|
||||||
Shares []*entities.AccountShare
|
Shares []*entities.AccountShare
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountPort defines the input port for getting accounts
|
// GetAccountPort defines the input port for getting accounts
|
||||||
type GetAccountPort interface {
|
type GetAccountPort interface {
|
||||||
Execute(ctx context.Context, input GetAccountInput) (*GetAccountOutput, error)
|
Execute(ctx context.Context, input GetAccountInput) (*GetAccountOutput, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginInput represents input for login
|
// LoginInput represents input for login
|
||||||
type LoginInput struct {
|
type LoginInput struct {
|
||||||
Username string
|
Username string
|
||||||
Challenge []byte
|
Challenge []byte
|
||||||
Signature []byte
|
Signature []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginOutput represents output from login
|
// LoginOutput represents output from login
|
||||||
type LoginOutput struct {
|
type LoginOutput struct {
|
||||||
Account *entities.Account
|
Account *entities.Account
|
||||||
AccessToken string
|
AccessToken string
|
||||||
RefreshToken string
|
RefreshToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginPort defines the input port for login
|
// LoginPort defines the input port for login
|
||||||
type LoginPort interface {
|
type LoginPort interface {
|
||||||
Execute(ctx context.Context, input LoginInput) (*LoginOutput, error)
|
Execute(ctx context.Context, input LoginInput) (*LoginOutput, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitiateRecoveryInput represents input for initiating recovery
|
// InitiateRecoveryInput represents input for initiating recovery
|
||||||
type InitiateRecoveryInput struct {
|
type InitiateRecoveryInput struct {
|
||||||
AccountID value_objects.AccountID
|
AccountID value_objects.AccountID
|
||||||
RecoveryType value_objects.RecoveryType
|
RecoveryType value_objects.RecoveryType
|
||||||
OldShareType *value_objects.ShareType
|
OldShareType *value_objects.ShareType
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitiateRecoveryOutput represents output from initiating recovery
|
// InitiateRecoveryOutput represents output from initiating recovery
|
||||||
type InitiateRecoveryOutput struct {
|
type InitiateRecoveryOutput struct {
|
||||||
RecoverySession *entities.RecoverySession
|
RecoverySession *entities.RecoverySession
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitiateRecoveryPort defines the input port for initiating recovery
|
// InitiateRecoveryPort defines the input port for initiating recovery
|
||||||
type InitiateRecoveryPort interface {
|
type InitiateRecoveryPort interface {
|
||||||
Execute(ctx context.Context, input InitiateRecoveryInput) (*InitiateRecoveryOutput, error)
|
Execute(ctx context.Context, input InitiateRecoveryInput) (*InitiateRecoveryOutput, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CompleteRecoveryInput represents input for completing recovery
|
// CompleteRecoveryInput represents input for completing recovery
|
||||||
type CompleteRecoveryInput struct {
|
type CompleteRecoveryInput struct {
|
||||||
RecoverySessionID string
|
RecoverySessionID string
|
||||||
NewPublicKey []byte
|
NewPublicKey []byte
|
||||||
NewKeygenSessionID uuid.UUID
|
NewKeygenSessionID uuid.UUID
|
||||||
NewShares []ShareInput
|
NewShares []ShareInput
|
||||||
}
|
}
|
||||||
|
|
||||||
// CompleteRecoveryOutput represents output from completing recovery
|
// CompleteRecoveryOutput represents output from completing recovery
|
||||||
type CompleteRecoveryOutput struct {
|
type CompleteRecoveryOutput struct {
|
||||||
Account *entities.Account
|
Account *entities.Account
|
||||||
}
|
}
|
||||||
|
|
||||||
// CompleteRecoveryPort defines the input port for completing recovery
|
// CompleteRecoveryPort defines the input port for completing recovery
|
||||||
type CompleteRecoveryPort interface {
|
type CompleteRecoveryPort interface {
|
||||||
Execute(ctx context.Context, input CompleteRecoveryInput) (*CompleteRecoveryOutput, error)
|
Execute(ctx context.Context, input CompleteRecoveryInput) (*CompleteRecoveryOutput, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccountInput represents input for updating an account
|
// UpdateAccountInput represents input for updating an account
|
||||||
type UpdateAccountInput struct {
|
type UpdateAccountInput struct {
|
||||||
AccountID value_objects.AccountID
|
AccountID value_objects.AccountID
|
||||||
Phone *string
|
Phone *string
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccountOutput represents output from updating an account
|
// UpdateAccountOutput represents output from updating an account
|
||||||
type UpdateAccountOutput struct {
|
type UpdateAccountOutput struct {
|
||||||
Account *entities.Account
|
Account *entities.Account
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccountPort defines the input port for updating accounts
|
// UpdateAccountPort defines the input port for updating accounts
|
||||||
type UpdateAccountPort interface {
|
type UpdateAccountPort interface {
|
||||||
Execute(ctx context.Context, input UpdateAccountInput) (*UpdateAccountOutput, error)
|
Execute(ctx context.Context, input UpdateAccountInput) (*UpdateAccountOutput, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeactivateShareInput represents input for deactivating a share
|
// DeactivateShareInput represents input for deactivating a share
|
||||||
type DeactivateShareInput struct {
|
type DeactivateShareInput struct {
|
||||||
AccountID value_objects.AccountID
|
AccountID value_objects.AccountID
|
||||||
ShareID string
|
ShareID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeactivateSharePort defines the input port for deactivating shares
|
// DeactivateSharePort defines the input port for deactivating shares
|
||||||
type DeactivateSharePort interface {
|
type DeactivateSharePort interface {
|
||||||
Execute(ctx context.Context, input DeactivateShareInput) error
|
Execute(ctx context.Context, input DeactivateShareInput) error
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,76 +1,76 @@
|
||||||
package ports
|
package ports
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
)
|
)
|
||||||
|
|
||||||
// EventType represents the type of account event
|
// EventType represents the type of account event
|
||||||
type EventType string
|
type EventType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
EventTypeAccountCreated EventType = "account.created"
|
EventTypeAccountCreated EventType = "account.created"
|
||||||
EventTypeAccountUpdated EventType = "account.updated"
|
EventTypeAccountUpdated EventType = "account.updated"
|
||||||
EventTypeAccountDeleted EventType = "account.deleted"
|
EventTypeAccountDeleted EventType = "account.deleted"
|
||||||
EventTypeAccountLogin EventType = "account.login"
|
EventTypeAccountLogin EventType = "account.login"
|
||||||
EventTypeRecoveryStarted EventType = "account.recovery.started"
|
EventTypeRecoveryStarted EventType = "account.recovery.started"
|
||||||
EventTypeRecoveryComplete EventType = "account.recovery.completed"
|
EventTypeRecoveryComplete EventType = "account.recovery.completed"
|
||||||
EventTypeShareDeactivated EventType = "account.share.deactivated"
|
EventTypeShareDeactivated EventType = "account.share.deactivated"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccountEvent represents an account-related event
|
// AccountEvent represents an account-related event
|
||||||
type AccountEvent struct {
|
type AccountEvent struct {
|
||||||
Type EventType
|
Type EventType
|
||||||
AccountID string
|
AccountID string
|
||||||
Data map[string]interface{}
|
Data map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventPublisher defines the output port for publishing events
|
// EventPublisher defines the output port for publishing events
|
||||||
type EventPublisher interface {
|
type EventPublisher interface {
|
||||||
// Publish publishes an account event
|
// Publish publishes an account event
|
||||||
Publish(ctx context.Context, event AccountEvent) error
|
Publish(ctx context.Context, event AccountEvent) error
|
||||||
|
|
||||||
// Close closes the publisher
|
// Close closes the publisher
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenService defines the output port for token operations
|
// TokenService defines the output port for token operations
|
||||||
type TokenService interface {
|
type TokenService interface {
|
||||||
// GenerateAccessToken generates an access token for an account
|
// GenerateAccessToken generates an access token for an account
|
||||||
GenerateAccessToken(accountID, username string) (string, error)
|
GenerateAccessToken(accountID, username string) (string, error)
|
||||||
|
|
||||||
// GenerateRefreshToken generates a refresh token for an account
|
// GenerateRefreshToken generates a refresh token for an account
|
||||||
GenerateRefreshToken(accountID string) (string, error)
|
GenerateRefreshToken(accountID string) (string, error)
|
||||||
|
|
||||||
// ValidateAccessToken validates an access token
|
// ValidateAccessToken validates an access token
|
||||||
ValidateAccessToken(token string) (claims map[string]interface{}, err error)
|
ValidateAccessToken(token string) (claims map[string]interface{}, err error)
|
||||||
|
|
||||||
// ValidateRefreshToken validates a refresh token
|
// ValidateRefreshToken validates a refresh token
|
||||||
ValidateRefreshToken(token string) (accountID string, err error)
|
ValidateRefreshToken(token string) (accountID string, err error)
|
||||||
|
|
||||||
// RefreshAccessToken refreshes an access token using a refresh token
|
// RefreshAccessToken refreshes an access token using a refresh token
|
||||||
RefreshAccessToken(refreshToken string) (accessToken string, err error)
|
RefreshAccessToken(refreshToken string) (accessToken string, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionCoordinatorClient defines the output port for session coordinator communication
|
// SessionCoordinatorClient defines the output port for session coordinator communication
|
||||||
type SessionCoordinatorClient interface {
|
type SessionCoordinatorClient interface {
|
||||||
// GetSessionStatus gets the status of a keygen session
|
// GetSessionStatus gets the status of a keygen session
|
||||||
GetSessionStatus(ctx context.Context, sessionID string) (status string, err error)
|
GetSessionStatus(ctx context.Context, sessionID string) (status string, err error)
|
||||||
|
|
||||||
// IsSessionCompleted checks if a session is completed
|
// IsSessionCompleted checks if a session is completed
|
||||||
IsSessionCompleted(ctx context.Context, sessionID string) (bool, error)
|
IsSessionCompleted(ctx context.Context, sessionID string) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CacheService defines the output port for caching
|
// CacheService defines the output port for caching
|
||||||
type CacheService interface {
|
type CacheService interface {
|
||||||
// Set sets a value in the cache
|
// Set sets a value in the cache
|
||||||
Set(ctx context.Context, key string, value interface{}, ttlSeconds int) error
|
Set(ctx context.Context, key string, value interface{}, ttlSeconds int) error
|
||||||
|
|
||||||
// Get gets a value from the cache
|
// Get gets a value from the cache
|
||||||
Get(ctx context.Context, key string) (interface{}, error)
|
Get(ctx context.Context, key string) (interface{}, error)
|
||||||
|
|
||||||
// Delete deletes a value from the cache
|
// Delete deletes a value from the cache
|
||||||
Delete(ctx context.Context, key string) error
|
Delete(ctx context.Context, key string) error
|
||||||
|
|
||||||
// Exists checks if a key exists in the cache
|
// Exists checks if a key exists in the cache
|
||||||
Exists(ctx context.Context, key string) (bool, error)
|
Exists(ctx context.Context, key string) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,333 +1,333 @@
|
||||||
package use_cases
|
package use_cases
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/services"
|
"github.com/rwadurian/mpc-system/services/account/domain/services"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateAccountUseCase handles account creation
|
// CreateAccountUseCase handles account creation
|
||||||
type CreateAccountUseCase struct {
|
type CreateAccountUseCase struct {
|
||||||
accountRepo repositories.AccountRepository
|
accountRepo repositories.AccountRepository
|
||||||
shareRepo repositories.AccountShareRepository
|
shareRepo repositories.AccountShareRepository
|
||||||
domainService *services.AccountDomainService
|
domainService *services.AccountDomainService
|
||||||
eventPublisher ports.EventPublisher
|
eventPublisher ports.EventPublisher
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCreateAccountUseCase creates a new CreateAccountUseCase
|
// NewCreateAccountUseCase creates a new CreateAccountUseCase
|
||||||
func NewCreateAccountUseCase(
|
func NewCreateAccountUseCase(
|
||||||
accountRepo repositories.AccountRepository,
|
accountRepo repositories.AccountRepository,
|
||||||
shareRepo repositories.AccountShareRepository,
|
shareRepo repositories.AccountShareRepository,
|
||||||
domainService *services.AccountDomainService,
|
domainService *services.AccountDomainService,
|
||||||
eventPublisher ports.EventPublisher,
|
eventPublisher ports.EventPublisher,
|
||||||
) *CreateAccountUseCase {
|
) *CreateAccountUseCase {
|
||||||
return &CreateAccountUseCase{
|
return &CreateAccountUseCase{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
shareRepo: shareRepo,
|
shareRepo: shareRepo,
|
||||||
domainService: domainService,
|
domainService: domainService,
|
||||||
eventPublisher: eventPublisher,
|
eventPublisher: eventPublisher,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute creates a new account
|
// Execute creates a new account
|
||||||
func (uc *CreateAccountUseCase) Execute(ctx context.Context, input ports.CreateAccountInput) (*ports.CreateAccountOutput, error) {
|
func (uc *CreateAccountUseCase) Execute(ctx context.Context, input ports.CreateAccountInput) (*ports.CreateAccountOutput, error) {
|
||||||
// Convert shares input
|
// Convert shares input
|
||||||
shares := make([]services.ShareInfo, len(input.Shares))
|
shares := make([]services.ShareInfo, len(input.Shares))
|
||||||
for i, s := range input.Shares {
|
for i, s := range input.Shares {
|
||||||
shares[i] = services.ShareInfo{
|
shares[i] = services.ShareInfo{
|
||||||
ShareType: s.ShareType,
|
ShareType: s.ShareType,
|
||||||
PartyID: s.PartyID,
|
PartyID: s.PartyID,
|
||||||
PartyIndex: s.PartyIndex,
|
PartyIndex: s.PartyIndex,
|
||||||
DeviceType: s.DeviceType,
|
DeviceType: s.DeviceType,
|
||||||
DeviceID: s.DeviceID,
|
DeviceID: s.DeviceID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create account using domain service
|
// Create account using domain service
|
||||||
account, err := uc.domainService.CreateAccount(ctx, services.CreateAccountInput{
|
account, err := uc.domainService.CreateAccount(ctx, services.CreateAccountInput{
|
||||||
Username: input.Username,
|
Username: input.Username,
|
||||||
Email: input.Email,
|
Email: input.Email,
|
||||||
Phone: input.Phone,
|
Phone: input.Phone,
|
||||||
PublicKey: input.PublicKey,
|
PublicKey: input.PublicKey,
|
||||||
KeygenSessionID: input.KeygenSessionID,
|
KeygenSessionID: input.KeygenSessionID,
|
||||||
ThresholdN: input.ThresholdN,
|
ThresholdN: input.ThresholdN,
|
||||||
ThresholdT: input.ThresholdT,
|
ThresholdT: input.ThresholdT,
|
||||||
Shares: shares,
|
Shares: shares,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get created shares
|
// Get created shares
|
||||||
accountShares, err := uc.shareRepo.GetByAccountID(ctx, account.ID)
|
accountShares, err := uc.shareRepo.GetByAccountID(ctx, account.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish event
|
// Publish event
|
||||||
if uc.eventPublisher != nil {
|
if uc.eventPublisher != nil {
|
||||||
_ = uc.eventPublisher.Publish(ctx, ports.AccountEvent{
|
_ = uc.eventPublisher.Publish(ctx, ports.AccountEvent{
|
||||||
Type: ports.EventTypeAccountCreated,
|
Type: ports.EventTypeAccountCreated,
|
||||||
AccountID: account.ID.String(),
|
AccountID: account.ID.String(),
|
||||||
Data: map[string]interface{}{
|
Data: map[string]interface{}{
|
||||||
"username": account.Username,
|
"username": account.Username,
|
||||||
"email": account.Email,
|
"email": account.Email,
|
||||||
"thresholdN": account.ThresholdN,
|
"thresholdN": account.ThresholdN,
|
||||||
"thresholdT": account.ThresholdT,
|
"thresholdT": account.ThresholdT,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ports.CreateAccountOutput{
|
return &ports.CreateAccountOutput{
|
||||||
Account: account,
|
Account: account,
|
||||||
Shares: accountShares,
|
Shares: accountShares,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountUseCase handles getting account information
|
// GetAccountUseCase handles getting account information
|
||||||
type GetAccountUseCase struct {
|
type GetAccountUseCase struct {
|
||||||
accountRepo repositories.AccountRepository
|
accountRepo repositories.AccountRepository
|
||||||
shareRepo repositories.AccountShareRepository
|
shareRepo repositories.AccountShareRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGetAccountUseCase creates a new GetAccountUseCase
|
// NewGetAccountUseCase creates a new GetAccountUseCase
|
||||||
func NewGetAccountUseCase(
|
func NewGetAccountUseCase(
|
||||||
accountRepo repositories.AccountRepository,
|
accountRepo repositories.AccountRepository,
|
||||||
shareRepo repositories.AccountShareRepository,
|
shareRepo repositories.AccountShareRepository,
|
||||||
) *GetAccountUseCase {
|
) *GetAccountUseCase {
|
||||||
return &GetAccountUseCase{
|
return &GetAccountUseCase{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
shareRepo: shareRepo,
|
shareRepo: shareRepo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute gets account information
|
// Execute gets account information
|
||||||
func (uc *GetAccountUseCase) Execute(ctx context.Context, input ports.GetAccountInput) (*ports.GetAccountOutput, error) {
|
func (uc *GetAccountUseCase) Execute(ctx context.Context, input ports.GetAccountInput) (*ports.GetAccountOutput, error) {
|
||||||
var account *entities.Account
|
var account *entities.Account
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case input.AccountID != nil:
|
case input.AccountID != nil:
|
||||||
account, err = uc.accountRepo.GetByID(ctx, *input.AccountID)
|
account, err = uc.accountRepo.GetByID(ctx, *input.AccountID)
|
||||||
case input.Username != nil:
|
case input.Username != nil:
|
||||||
account, err = uc.accountRepo.GetByUsername(ctx, *input.Username)
|
account, err = uc.accountRepo.GetByUsername(ctx, *input.Username)
|
||||||
case input.Email != nil:
|
case input.Email != nil:
|
||||||
account, err = uc.accountRepo.GetByEmail(ctx, *input.Email)
|
account, err = uc.accountRepo.GetByEmail(ctx, *input.Email)
|
||||||
default:
|
default:
|
||||||
return nil, entities.ErrAccountNotFound
|
return nil, entities.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get shares
|
// Get shares
|
||||||
shares, err := uc.shareRepo.GetActiveByAccountID(ctx, account.ID)
|
shares, err := uc.shareRepo.GetActiveByAccountID(ctx, account.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ports.GetAccountOutput{
|
return &ports.GetAccountOutput{
|
||||||
Account: account,
|
Account: account,
|
||||||
Shares: shares,
|
Shares: shares,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccountUseCase handles account updates
|
// UpdateAccountUseCase handles account updates
|
||||||
type UpdateAccountUseCase struct {
|
type UpdateAccountUseCase struct {
|
||||||
accountRepo repositories.AccountRepository
|
accountRepo repositories.AccountRepository
|
||||||
eventPublisher ports.EventPublisher
|
eventPublisher ports.EventPublisher
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUpdateAccountUseCase creates a new UpdateAccountUseCase
|
// NewUpdateAccountUseCase creates a new UpdateAccountUseCase
|
||||||
func NewUpdateAccountUseCase(
|
func NewUpdateAccountUseCase(
|
||||||
accountRepo repositories.AccountRepository,
|
accountRepo repositories.AccountRepository,
|
||||||
eventPublisher ports.EventPublisher,
|
eventPublisher ports.EventPublisher,
|
||||||
) *UpdateAccountUseCase {
|
) *UpdateAccountUseCase {
|
||||||
return &UpdateAccountUseCase{
|
return &UpdateAccountUseCase{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
eventPublisher: eventPublisher,
|
eventPublisher: eventPublisher,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute updates an account
|
// Execute updates an account
|
||||||
func (uc *UpdateAccountUseCase) Execute(ctx context.Context, input ports.UpdateAccountInput) (*ports.UpdateAccountOutput, error) {
|
func (uc *UpdateAccountUseCase) Execute(ctx context.Context, input ports.UpdateAccountInput) (*ports.UpdateAccountOutput, error) {
|
||||||
account, err := uc.accountRepo.GetByID(ctx, input.AccountID)
|
account, err := uc.accountRepo.GetByID(ctx, input.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Phone != nil {
|
if input.Phone != nil {
|
||||||
account.SetPhone(*input.Phone)
|
account.SetPhone(*input.Phone)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := uc.accountRepo.Update(ctx, account); err != nil {
|
if err := uc.accountRepo.Update(ctx, account); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish event
|
// Publish event
|
||||||
if uc.eventPublisher != nil {
|
if uc.eventPublisher != nil {
|
||||||
_ = uc.eventPublisher.Publish(ctx, ports.AccountEvent{
|
_ = uc.eventPublisher.Publish(ctx, ports.AccountEvent{
|
||||||
Type: ports.EventTypeAccountUpdated,
|
Type: ports.EventTypeAccountUpdated,
|
||||||
AccountID: account.ID.String(),
|
AccountID: account.ID.String(),
|
||||||
Data: map[string]interface{}{},
|
Data: map[string]interface{}{},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ports.UpdateAccountOutput{
|
return &ports.UpdateAccountOutput{
|
||||||
Account: account,
|
Account: account,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeactivateShareUseCase handles share deactivation
|
// DeactivateShareUseCase handles share deactivation
|
||||||
type DeactivateShareUseCase struct {
|
type DeactivateShareUseCase struct {
|
||||||
accountRepo repositories.AccountRepository
|
accountRepo repositories.AccountRepository
|
||||||
shareRepo repositories.AccountShareRepository
|
shareRepo repositories.AccountShareRepository
|
||||||
eventPublisher ports.EventPublisher
|
eventPublisher ports.EventPublisher
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDeactivateShareUseCase creates a new DeactivateShareUseCase
|
// NewDeactivateShareUseCase creates a new DeactivateShareUseCase
|
||||||
func NewDeactivateShareUseCase(
|
func NewDeactivateShareUseCase(
|
||||||
accountRepo repositories.AccountRepository,
|
accountRepo repositories.AccountRepository,
|
||||||
shareRepo repositories.AccountShareRepository,
|
shareRepo repositories.AccountShareRepository,
|
||||||
eventPublisher ports.EventPublisher,
|
eventPublisher ports.EventPublisher,
|
||||||
) *DeactivateShareUseCase {
|
) *DeactivateShareUseCase {
|
||||||
return &DeactivateShareUseCase{
|
return &DeactivateShareUseCase{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
shareRepo: shareRepo,
|
shareRepo: shareRepo,
|
||||||
eventPublisher: eventPublisher,
|
eventPublisher: eventPublisher,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute deactivates a share
|
// Execute deactivates a share
|
||||||
func (uc *DeactivateShareUseCase) Execute(ctx context.Context, input ports.DeactivateShareInput) error {
|
func (uc *DeactivateShareUseCase) Execute(ctx context.Context, input ports.DeactivateShareInput) error {
|
||||||
// Verify account exists
|
// Verify account exists
|
||||||
_, err := uc.accountRepo.GetByID(ctx, input.AccountID)
|
_, err := uc.accountRepo.GetByID(ctx, input.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get share
|
// Get share
|
||||||
share, err := uc.shareRepo.GetByID(ctx, input.ShareID)
|
share, err := uc.shareRepo.GetByID(ctx, input.ShareID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify share belongs to account
|
// Verify share belongs to account
|
||||||
if !share.AccountID.Equals(input.AccountID) {
|
if !share.AccountID.Equals(input.AccountID) {
|
||||||
return entities.ErrShareNotFound
|
return entities.ErrShareNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deactivate share
|
// Deactivate share
|
||||||
share.Deactivate()
|
share.Deactivate()
|
||||||
if err := uc.shareRepo.Update(ctx, share); err != nil {
|
if err := uc.shareRepo.Update(ctx, share); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish event
|
// Publish event
|
||||||
if uc.eventPublisher != nil {
|
if uc.eventPublisher != nil {
|
||||||
_ = uc.eventPublisher.Publish(ctx, ports.AccountEvent{
|
_ = uc.eventPublisher.Publish(ctx, ports.AccountEvent{
|
||||||
Type: ports.EventTypeShareDeactivated,
|
Type: ports.EventTypeShareDeactivated,
|
||||||
AccountID: input.AccountID.String(),
|
AccountID: input.AccountID.String(),
|
||||||
Data: map[string]interface{}{
|
Data: map[string]interface{}{
|
||||||
"shareId": input.ShareID,
|
"shareId": input.ShareID,
|
||||||
"shareType": share.ShareType.String(),
|
"shareType": share.ShareType.String(),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListAccountsInput represents input for listing accounts
|
// ListAccountsInput represents input for listing accounts
|
||||||
type ListAccountsInput struct {
|
type ListAccountsInput struct {
|
||||||
Offset int
|
Offset int
|
||||||
Limit int
|
Limit int
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListAccountsOutput represents output from listing accounts
|
// ListAccountsOutput represents output from listing accounts
|
||||||
type ListAccountsOutput struct {
|
type ListAccountsOutput struct {
|
||||||
Accounts []*entities.Account
|
Accounts []*entities.Account
|
||||||
Total int64
|
Total int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListAccountsUseCase handles listing accounts
|
// ListAccountsUseCase handles listing accounts
|
||||||
type ListAccountsUseCase struct {
|
type ListAccountsUseCase struct {
|
||||||
accountRepo repositories.AccountRepository
|
accountRepo repositories.AccountRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewListAccountsUseCase creates a new ListAccountsUseCase
|
// NewListAccountsUseCase creates a new ListAccountsUseCase
|
||||||
func NewListAccountsUseCase(accountRepo repositories.AccountRepository) *ListAccountsUseCase {
|
func NewListAccountsUseCase(accountRepo repositories.AccountRepository) *ListAccountsUseCase {
|
||||||
return &ListAccountsUseCase{
|
return &ListAccountsUseCase{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute lists accounts with pagination
|
// Execute lists accounts with pagination
|
||||||
func (uc *ListAccountsUseCase) Execute(ctx context.Context, input ListAccountsInput) (*ListAccountsOutput, error) {
|
func (uc *ListAccountsUseCase) Execute(ctx context.Context, input ListAccountsInput) (*ListAccountsOutput, error) {
|
||||||
if input.Limit <= 0 {
|
if input.Limit <= 0 {
|
||||||
input.Limit = 20
|
input.Limit = 20
|
||||||
}
|
}
|
||||||
if input.Limit > 100 {
|
if input.Limit > 100 {
|
||||||
input.Limit = 100
|
input.Limit = 100
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, err := uc.accountRepo.List(ctx, input.Offset, input.Limit)
|
accounts, err := uc.accountRepo.List(ctx, input.Offset, input.Limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
total, err := uc.accountRepo.Count(ctx)
|
total, err := uc.accountRepo.Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ListAccountsOutput{
|
return &ListAccountsOutput{
|
||||||
Accounts: accounts,
|
Accounts: accounts,
|
||||||
Total: total,
|
Total: total,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountSharesUseCase handles getting account shares
|
// GetAccountSharesUseCase handles getting account shares
|
||||||
type GetAccountSharesUseCase struct {
|
type GetAccountSharesUseCase struct {
|
||||||
accountRepo repositories.AccountRepository
|
accountRepo repositories.AccountRepository
|
||||||
shareRepo repositories.AccountShareRepository
|
shareRepo repositories.AccountShareRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGetAccountSharesUseCase creates a new GetAccountSharesUseCase
|
// NewGetAccountSharesUseCase creates a new GetAccountSharesUseCase
|
||||||
func NewGetAccountSharesUseCase(
|
func NewGetAccountSharesUseCase(
|
||||||
accountRepo repositories.AccountRepository,
|
accountRepo repositories.AccountRepository,
|
||||||
shareRepo repositories.AccountShareRepository,
|
shareRepo repositories.AccountShareRepository,
|
||||||
) *GetAccountSharesUseCase {
|
) *GetAccountSharesUseCase {
|
||||||
return &GetAccountSharesUseCase{
|
return &GetAccountSharesUseCase{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
shareRepo: shareRepo,
|
shareRepo: shareRepo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountSharesOutput represents output from getting account shares
|
// GetAccountSharesOutput represents output from getting account shares
|
||||||
type GetAccountSharesOutput struct {
|
type GetAccountSharesOutput struct {
|
||||||
Shares []*entities.AccountShare
|
Shares []*entities.AccountShare
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute gets shares for an account
|
// Execute gets shares for an account
|
||||||
func (uc *GetAccountSharesUseCase) Execute(ctx context.Context, accountID value_objects.AccountID) (*GetAccountSharesOutput, error) {
|
func (uc *GetAccountSharesUseCase) Execute(ctx context.Context, accountID value_objects.AccountID) (*GetAccountSharesOutput, error) {
|
||||||
// Verify account exists
|
// Verify account exists
|
||||||
_, err := uc.accountRepo.GetByID(ctx, accountID)
|
_, err := uc.accountRepo.GetByID(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
shares, err := uc.shareRepo.GetByAccountID(ctx, accountID)
|
shares, err := uc.shareRepo.GetByAccountID(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &GetAccountSharesOutput{
|
return &GetAccountSharesOutput{
|
||||||
Shares: shares,
|
Shares: shares,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,253 +1,253 @@
|
||||||
package use_cases
|
package use_cases
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rwadurian/mpc-system/pkg/crypto"
|
"github.com/rwadurian/mpc-system/pkg/crypto"
|
||||||
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LoginError represents a login error
|
// LoginError represents a login error
|
||||||
type LoginError struct {
|
type LoginError struct {
|
||||||
Code string
|
Code string
|
||||||
Message string
|
Message string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *LoginError) Error() string {
|
func (e *LoginError) Error() string {
|
||||||
return e.Message
|
return e.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidCredentials = &LoginError{Code: "INVALID_CREDENTIALS", Message: "invalid username or signature"}
|
ErrInvalidCredentials = &LoginError{Code: "INVALID_CREDENTIALS", Message: "invalid username or signature"}
|
||||||
ErrAccountLocked = &LoginError{Code: "ACCOUNT_LOCKED", Message: "account is locked"}
|
ErrAccountLocked = &LoginError{Code: "ACCOUNT_LOCKED", Message: "account is locked"}
|
||||||
ErrAccountSuspended = &LoginError{Code: "ACCOUNT_SUSPENDED", Message: "account is suspended"}
|
ErrAccountSuspended = &LoginError{Code: "ACCOUNT_SUSPENDED", Message: "account is suspended"}
|
||||||
ErrSignatureInvalid = &LoginError{Code: "SIGNATURE_INVALID", Message: "signature verification failed"}
|
ErrSignatureInvalid = &LoginError{Code: "SIGNATURE_INVALID", Message: "signature verification failed"}
|
||||||
)
|
)
|
||||||
|
|
||||||
// LoginUseCase handles user login with MPC signature verification
|
// LoginUseCase handles user login with MPC signature verification
|
||||||
type LoginUseCase struct {
|
type LoginUseCase struct {
|
||||||
accountRepo repositories.AccountRepository
|
accountRepo repositories.AccountRepository
|
||||||
shareRepo repositories.AccountShareRepository
|
shareRepo repositories.AccountShareRepository
|
||||||
tokenService ports.TokenService
|
tokenService ports.TokenService
|
||||||
eventPublisher ports.EventPublisher
|
eventPublisher ports.EventPublisher
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewLoginUseCase creates a new LoginUseCase
|
// NewLoginUseCase creates a new LoginUseCase
|
||||||
func NewLoginUseCase(
|
func NewLoginUseCase(
|
||||||
accountRepo repositories.AccountRepository,
|
accountRepo repositories.AccountRepository,
|
||||||
shareRepo repositories.AccountShareRepository,
|
shareRepo repositories.AccountShareRepository,
|
||||||
tokenService ports.TokenService,
|
tokenService ports.TokenService,
|
||||||
eventPublisher ports.EventPublisher,
|
eventPublisher ports.EventPublisher,
|
||||||
) *LoginUseCase {
|
) *LoginUseCase {
|
||||||
return &LoginUseCase{
|
return &LoginUseCase{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
shareRepo: shareRepo,
|
shareRepo: shareRepo,
|
||||||
tokenService: tokenService,
|
tokenService: tokenService,
|
||||||
eventPublisher: eventPublisher,
|
eventPublisher: eventPublisher,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute performs login with signature verification
|
// Execute performs login with signature verification
|
||||||
func (uc *LoginUseCase) Execute(ctx context.Context, input ports.LoginInput) (*ports.LoginOutput, error) {
|
func (uc *LoginUseCase) Execute(ctx context.Context, input ports.LoginInput) (*ports.LoginOutput, error) {
|
||||||
// Get account by username
|
// Get account by username
|
||||||
account, err := uc.accountRepo.GetByUsername(ctx, input.Username)
|
account, err := uc.accountRepo.GetByUsername(ctx, input.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidCredentials
|
return nil, ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check account status
|
// Check account status
|
||||||
if !account.CanLogin() {
|
if !account.CanLogin() {
|
||||||
switch account.Status.String() {
|
switch account.Status.String() {
|
||||||
case "locked":
|
case "locked":
|
||||||
return nil, ErrAccountLocked
|
return nil, ErrAccountLocked
|
||||||
case "suspended":
|
case "suspended":
|
||||||
return nil, ErrAccountSuspended
|
return nil, ErrAccountSuspended
|
||||||
default:
|
default:
|
||||||
return nil, entities.ErrAccountNotActive
|
return nil, entities.ErrAccountNotActive
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse public key
|
// Parse public key
|
||||||
pubKey, err := crypto.ParsePublicKey(account.PublicKey)
|
pubKey, err := crypto.ParsePublicKey(account.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrSignatureInvalid
|
return nil, ErrSignatureInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify signature (hash the challenge first, as SignMessage does)
|
// Verify signature (hash the challenge first, as SignMessage does)
|
||||||
challengeHash := crypto.HashMessage(input.Challenge)
|
challengeHash := crypto.HashMessage(input.Challenge)
|
||||||
if !crypto.VerifySignature(pubKey, challengeHash, input.Signature) {
|
if !crypto.VerifySignature(pubKey, challengeHash, input.Signature) {
|
||||||
return nil, ErrSignatureInvalid
|
return nil, ErrSignatureInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update last login
|
// Update last login
|
||||||
account.UpdateLastLogin()
|
account.UpdateLastLogin()
|
||||||
if err := uc.accountRepo.Update(ctx, account); err != nil {
|
if err := uc.accountRepo.Update(ctx, account); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate tokens
|
// Generate tokens
|
||||||
accessToken, err := uc.tokenService.GenerateAccessToken(account.ID.String(), account.Username)
|
accessToken, err := uc.tokenService.GenerateAccessToken(account.ID.String(), account.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshToken, err := uc.tokenService.GenerateRefreshToken(account.ID.String())
|
refreshToken, err := uc.tokenService.GenerateRefreshToken(account.ID.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish login event
|
// Publish login event
|
||||||
if uc.eventPublisher != nil {
|
if uc.eventPublisher != nil {
|
||||||
_ = uc.eventPublisher.Publish(ctx, ports.AccountEvent{
|
_ = uc.eventPublisher.Publish(ctx, ports.AccountEvent{
|
||||||
Type: ports.EventTypeAccountLogin,
|
Type: ports.EventTypeAccountLogin,
|
||||||
AccountID: account.ID.String(),
|
AccountID: account.ID.String(),
|
||||||
Data: map[string]interface{}{
|
Data: map[string]interface{}{
|
||||||
"username": account.Username,
|
"username": account.Username,
|
||||||
"timestamp": time.Now().UTC(),
|
"timestamp": time.Now().UTC(),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ports.LoginOutput{
|
return &ports.LoginOutput{
|
||||||
Account: account,
|
Account: account,
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshTokenInput represents input for refreshing tokens
|
// RefreshTokenInput represents input for refreshing tokens
|
||||||
type RefreshTokenInput struct {
|
type RefreshTokenInput struct {
|
||||||
RefreshToken string
|
RefreshToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshTokenOutput represents output from refreshing tokens
|
// RefreshTokenOutput represents output from refreshing tokens
|
||||||
type RefreshTokenOutput struct {
|
type RefreshTokenOutput struct {
|
||||||
AccessToken string
|
AccessToken string
|
||||||
RefreshToken string
|
RefreshToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshTokenUseCase handles token refresh
|
// RefreshTokenUseCase handles token refresh
|
||||||
type RefreshTokenUseCase struct {
|
type RefreshTokenUseCase struct {
|
||||||
accountRepo repositories.AccountRepository
|
accountRepo repositories.AccountRepository
|
||||||
tokenService ports.TokenService
|
tokenService ports.TokenService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRefreshTokenUseCase creates a new RefreshTokenUseCase
|
// NewRefreshTokenUseCase creates a new RefreshTokenUseCase
|
||||||
func NewRefreshTokenUseCase(
|
func NewRefreshTokenUseCase(
|
||||||
accountRepo repositories.AccountRepository,
|
accountRepo repositories.AccountRepository,
|
||||||
tokenService ports.TokenService,
|
tokenService ports.TokenService,
|
||||||
) *RefreshTokenUseCase {
|
) *RefreshTokenUseCase {
|
||||||
return &RefreshTokenUseCase{
|
return &RefreshTokenUseCase{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
tokenService: tokenService,
|
tokenService: tokenService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute refreshes the access token
|
// Execute refreshes the access token
|
||||||
func (uc *RefreshTokenUseCase) Execute(ctx context.Context, input RefreshTokenInput) (*RefreshTokenOutput, error) {
|
func (uc *RefreshTokenUseCase) Execute(ctx context.Context, input RefreshTokenInput) (*RefreshTokenOutput, error) {
|
||||||
// Validate refresh token and get account ID
|
// Validate refresh token and get account ID
|
||||||
accountIDStr, err := uc.tokenService.ValidateRefreshToken(input.RefreshToken)
|
accountIDStr, err := uc.tokenService.ValidateRefreshToken(input.RefreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get account to verify it still exists and is active
|
// Get account to verify it still exists and is active
|
||||||
accountID, err := parseAccountID(accountIDStr)
|
accountID, err := parseAccountID(accountIDStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := uc.accountRepo.GetByID(ctx, accountID)
|
account, err := uc.accountRepo.GetByID(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !account.CanLogin() {
|
if !account.CanLogin() {
|
||||||
return nil, entities.ErrAccountNotActive
|
return nil, entities.ErrAccountNotActive
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate new access token
|
// Generate new access token
|
||||||
accessToken, err := uc.tokenService.GenerateAccessToken(account.ID.String(), account.Username)
|
accessToken, err := uc.tokenService.GenerateAccessToken(account.ID.String(), account.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate new refresh token
|
// Generate new refresh token
|
||||||
refreshToken, err := uc.tokenService.GenerateRefreshToken(account.ID.String())
|
refreshToken, err := uc.tokenService.GenerateRefreshToken(account.ID.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &RefreshTokenOutput{
|
return &RefreshTokenOutput{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateChallengeUseCase handles challenge generation for login
|
// GenerateChallengeUseCase handles challenge generation for login
|
||||||
type GenerateChallengeUseCase struct {
|
type GenerateChallengeUseCase struct {
|
||||||
cacheService ports.CacheService
|
cacheService ports.CacheService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGenerateChallengeUseCase creates a new GenerateChallengeUseCase
|
// NewGenerateChallengeUseCase creates a new GenerateChallengeUseCase
|
||||||
func NewGenerateChallengeUseCase(cacheService ports.CacheService) *GenerateChallengeUseCase {
|
func NewGenerateChallengeUseCase(cacheService ports.CacheService) *GenerateChallengeUseCase {
|
||||||
return &GenerateChallengeUseCase{
|
return &GenerateChallengeUseCase{
|
||||||
cacheService: cacheService,
|
cacheService: cacheService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateChallengeInput represents input for generating a challenge
|
// GenerateChallengeInput represents input for generating a challenge
|
||||||
type GenerateChallengeInput struct {
|
type GenerateChallengeInput struct {
|
||||||
Username string
|
Username string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateChallengeOutput represents output from generating a challenge
|
// GenerateChallengeOutput represents output from generating a challenge
|
||||||
type GenerateChallengeOutput struct {
|
type GenerateChallengeOutput struct {
|
||||||
Challenge []byte
|
Challenge []byte
|
||||||
ChallengeID string
|
ChallengeID string
|
||||||
ExpiresAt time.Time
|
ExpiresAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute generates a challenge for login
|
// Execute generates a challenge for login
|
||||||
func (uc *GenerateChallengeUseCase) Execute(ctx context.Context, input GenerateChallengeInput) (*GenerateChallengeOutput, error) {
|
func (uc *GenerateChallengeUseCase) Execute(ctx context.Context, input GenerateChallengeInput) (*GenerateChallengeOutput, error) {
|
||||||
// Generate random challenge
|
// Generate random challenge
|
||||||
challenge, err := crypto.GenerateRandomBytes(32)
|
challenge, err := crypto.GenerateRandomBytes(32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate challenge ID
|
// Generate challenge ID
|
||||||
challengeID, err := crypto.GenerateRandomBytes(16)
|
challengeID, err := crypto.GenerateRandomBytes(16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
challengeIDStr := hex.EncodeToString(challengeID)
|
challengeIDStr := hex.EncodeToString(challengeID)
|
||||||
expiresAt := time.Now().UTC().Add(5 * time.Minute)
|
expiresAt := time.Now().UTC().Add(5 * time.Minute)
|
||||||
|
|
||||||
// Store challenge in cache
|
// Store challenge in cache
|
||||||
cacheKey := "login_challenge:" + challengeIDStr
|
cacheKey := "login_challenge:" + challengeIDStr
|
||||||
if uc.cacheService != nil {
|
if uc.cacheService != nil {
|
||||||
_ = uc.cacheService.Set(ctx, cacheKey, map[string]interface{}{
|
_ = uc.cacheService.Set(ctx, cacheKey, map[string]interface{}{
|
||||||
"username": input.Username,
|
"username": input.Username,
|
||||||
"challenge": hex.EncodeToString(challenge),
|
"challenge": hex.EncodeToString(challenge),
|
||||||
"expiresAt": expiresAt,
|
"expiresAt": expiresAt,
|
||||||
}, 300) // 5 minutes TTL
|
}, 300) // 5 minutes TTL
|
||||||
}
|
}
|
||||||
|
|
||||||
return &GenerateChallengeOutput{
|
return &GenerateChallengeOutput{
|
||||||
Challenge: challenge,
|
Challenge: challenge,
|
||||||
ChallengeID: challengeIDStr,
|
ChallengeID: challengeIDStr,
|
||||||
ExpiresAt: expiresAt,
|
ExpiresAt: expiresAt,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// helper function to parse account ID
|
// helper function to parse account ID
|
||||||
func parseAccountID(s string) (value_objects.AccountID, error) {
|
func parseAccountID(s string) (value_objects.AccountID, error) {
|
||||||
return value_objects.AccountIDFromString(s)
|
return value_objects.AccountIDFromString(s)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,244 +1,244 @@
|
||||||
package use_cases
|
package use_cases
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
"github.com/rwadurian/mpc-system/services/account/application/ports"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/services"
|
"github.com/rwadurian/mpc-system/services/account/domain/services"
|
||||||
)
|
)
|
||||||
|
|
||||||
// InitiateRecoveryUseCase handles initiating account recovery
|
// InitiateRecoveryUseCase handles initiating account recovery
|
||||||
type InitiateRecoveryUseCase struct {
|
type InitiateRecoveryUseCase struct {
|
||||||
accountRepo repositories.AccountRepository
|
accountRepo repositories.AccountRepository
|
||||||
recoveryRepo repositories.RecoverySessionRepository
|
recoveryRepo repositories.RecoverySessionRepository
|
||||||
domainService *services.AccountDomainService
|
domainService *services.AccountDomainService
|
||||||
eventPublisher ports.EventPublisher
|
eventPublisher ports.EventPublisher
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewInitiateRecoveryUseCase creates a new InitiateRecoveryUseCase
|
// NewInitiateRecoveryUseCase creates a new InitiateRecoveryUseCase
|
||||||
func NewInitiateRecoveryUseCase(
|
func NewInitiateRecoveryUseCase(
|
||||||
accountRepo repositories.AccountRepository,
|
accountRepo repositories.AccountRepository,
|
||||||
recoveryRepo repositories.RecoverySessionRepository,
|
recoveryRepo repositories.RecoverySessionRepository,
|
||||||
domainService *services.AccountDomainService,
|
domainService *services.AccountDomainService,
|
||||||
eventPublisher ports.EventPublisher,
|
eventPublisher ports.EventPublisher,
|
||||||
) *InitiateRecoveryUseCase {
|
) *InitiateRecoveryUseCase {
|
||||||
return &InitiateRecoveryUseCase{
|
return &InitiateRecoveryUseCase{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
recoveryRepo: recoveryRepo,
|
recoveryRepo: recoveryRepo,
|
||||||
domainService: domainService,
|
domainService: domainService,
|
||||||
eventPublisher: eventPublisher,
|
eventPublisher: eventPublisher,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute initiates account recovery
|
// Execute initiates account recovery
|
||||||
func (uc *InitiateRecoveryUseCase) Execute(ctx context.Context, input ports.InitiateRecoveryInput) (*ports.InitiateRecoveryOutput, error) {
|
func (uc *InitiateRecoveryUseCase) Execute(ctx context.Context, input ports.InitiateRecoveryInput) (*ports.InitiateRecoveryOutput, error) {
|
||||||
// Check if there's already an active recovery session
|
// Check if there's already an active recovery session
|
||||||
existingRecovery, err := uc.recoveryRepo.GetActiveByAccountID(ctx, input.AccountID)
|
existingRecovery, err := uc.recoveryRepo.GetActiveByAccountID(ctx, input.AccountID)
|
||||||
if err == nil && existingRecovery != nil {
|
if err == nil && existingRecovery != nil {
|
||||||
return nil, &entities.AccountError{
|
return nil, &entities.AccountError{
|
||||||
Code: "RECOVERY_ALREADY_IN_PROGRESS",
|
Code: "RECOVERY_ALREADY_IN_PROGRESS",
|
||||||
Message: "there is already an active recovery session for this account",
|
Message: "there is already an active recovery session for this account",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initiate recovery using domain service
|
// Initiate recovery using domain service
|
||||||
recoverySession, err := uc.domainService.InitiateRecovery(ctx, input.AccountID, input.RecoveryType, input.OldShareType)
|
recoverySession, err := uc.domainService.InitiateRecovery(ctx, input.AccountID, input.RecoveryType, input.OldShareType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish event
|
// Publish event
|
||||||
if uc.eventPublisher != nil {
|
if uc.eventPublisher != nil {
|
||||||
_ = uc.eventPublisher.Publish(ctx, ports.AccountEvent{
|
_ = uc.eventPublisher.Publish(ctx, ports.AccountEvent{
|
||||||
Type: ports.EventTypeRecoveryStarted,
|
Type: ports.EventTypeRecoveryStarted,
|
||||||
AccountID: input.AccountID.String(),
|
AccountID: input.AccountID.String(),
|
||||||
Data: map[string]interface{}{
|
Data: map[string]interface{}{
|
||||||
"recoverySessionId": recoverySession.ID.String(),
|
"recoverySessionId": recoverySession.ID.String(),
|
||||||
"recoveryType": input.RecoveryType.String(),
|
"recoveryType": input.RecoveryType.String(),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ports.InitiateRecoveryOutput{
|
return &ports.InitiateRecoveryOutput{
|
||||||
RecoverySession: recoverySession,
|
RecoverySession: recoverySession,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CompleteRecoveryUseCase handles completing account recovery
|
// CompleteRecoveryUseCase handles completing account recovery
|
||||||
type CompleteRecoveryUseCase struct {
|
type CompleteRecoveryUseCase struct {
|
||||||
accountRepo repositories.AccountRepository
|
accountRepo repositories.AccountRepository
|
||||||
shareRepo repositories.AccountShareRepository
|
shareRepo repositories.AccountShareRepository
|
||||||
recoveryRepo repositories.RecoverySessionRepository
|
recoveryRepo repositories.RecoverySessionRepository
|
||||||
domainService *services.AccountDomainService
|
domainService *services.AccountDomainService
|
||||||
eventPublisher ports.EventPublisher
|
eventPublisher ports.EventPublisher
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCompleteRecoveryUseCase creates a new CompleteRecoveryUseCase
|
// NewCompleteRecoveryUseCase creates a new CompleteRecoveryUseCase
|
||||||
func NewCompleteRecoveryUseCase(
|
func NewCompleteRecoveryUseCase(
|
||||||
accountRepo repositories.AccountRepository,
|
accountRepo repositories.AccountRepository,
|
||||||
shareRepo repositories.AccountShareRepository,
|
shareRepo repositories.AccountShareRepository,
|
||||||
recoveryRepo repositories.RecoverySessionRepository,
|
recoveryRepo repositories.RecoverySessionRepository,
|
||||||
domainService *services.AccountDomainService,
|
domainService *services.AccountDomainService,
|
||||||
eventPublisher ports.EventPublisher,
|
eventPublisher ports.EventPublisher,
|
||||||
) *CompleteRecoveryUseCase {
|
) *CompleteRecoveryUseCase {
|
||||||
return &CompleteRecoveryUseCase{
|
return &CompleteRecoveryUseCase{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
shareRepo: shareRepo,
|
shareRepo: shareRepo,
|
||||||
recoveryRepo: recoveryRepo,
|
recoveryRepo: recoveryRepo,
|
||||||
domainService: domainService,
|
domainService: domainService,
|
||||||
eventPublisher: eventPublisher,
|
eventPublisher: eventPublisher,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute completes account recovery
|
// Execute completes account recovery
|
||||||
func (uc *CompleteRecoveryUseCase) Execute(ctx context.Context, input ports.CompleteRecoveryInput) (*ports.CompleteRecoveryOutput, error) {
|
func (uc *CompleteRecoveryUseCase) Execute(ctx context.Context, input ports.CompleteRecoveryInput) (*ports.CompleteRecoveryOutput, error) {
|
||||||
// Convert shares input
|
// Convert shares input
|
||||||
newShares := make([]services.ShareInfo, len(input.NewShares))
|
newShares := make([]services.ShareInfo, len(input.NewShares))
|
||||||
for i, s := range input.NewShares {
|
for i, s := range input.NewShares {
|
||||||
newShares[i] = services.ShareInfo{
|
newShares[i] = services.ShareInfo{
|
||||||
ShareType: s.ShareType,
|
ShareType: s.ShareType,
|
||||||
PartyID: s.PartyID,
|
PartyID: s.PartyID,
|
||||||
PartyIndex: s.PartyIndex,
|
PartyIndex: s.PartyIndex,
|
||||||
DeviceType: s.DeviceType,
|
DeviceType: s.DeviceType,
|
||||||
DeviceID: s.DeviceID,
|
DeviceID: s.DeviceID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Complete recovery using domain service
|
// Complete recovery using domain service
|
||||||
err := uc.domainService.CompleteRecovery(
|
err := uc.domainService.CompleteRecovery(
|
||||||
ctx,
|
ctx,
|
||||||
input.RecoverySessionID,
|
input.RecoverySessionID,
|
||||||
input.NewPublicKey,
|
input.NewPublicKey,
|
||||||
input.NewKeygenSessionID,
|
input.NewKeygenSessionID,
|
||||||
newShares,
|
newShares,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get recovery session to get account ID
|
// Get recovery session to get account ID
|
||||||
recovery, err := uc.recoveryRepo.GetByID(ctx, input.RecoverySessionID)
|
recovery, err := uc.recoveryRepo.GetByID(ctx, input.RecoverySessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get updated account
|
// Get updated account
|
||||||
account, err := uc.accountRepo.GetByID(ctx, recovery.AccountID)
|
account, err := uc.accountRepo.GetByID(ctx, recovery.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish event
|
// Publish event
|
||||||
if uc.eventPublisher != nil {
|
if uc.eventPublisher != nil {
|
||||||
_ = uc.eventPublisher.Publish(ctx, ports.AccountEvent{
|
_ = uc.eventPublisher.Publish(ctx, ports.AccountEvent{
|
||||||
Type: ports.EventTypeRecoveryComplete,
|
Type: ports.EventTypeRecoveryComplete,
|
||||||
AccountID: account.ID.String(),
|
AccountID: account.ID.String(),
|
||||||
Data: map[string]interface{}{
|
Data: map[string]interface{}{
|
||||||
"recoverySessionId": input.RecoverySessionID,
|
"recoverySessionId": input.RecoverySessionID,
|
||||||
"newKeygenSessionId": input.NewKeygenSessionID.String(),
|
"newKeygenSessionId": input.NewKeygenSessionID.String(),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ports.CompleteRecoveryOutput{
|
return &ports.CompleteRecoveryOutput{
|
||||||
Account: account,
|
Account: account,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRecoveryStatusInput represents input for getting recovery status
|
// GetRecoveryStatusInput represents input for getting recovery status
|
||||||
type GetRecoveryStatusInput struct {
|
type GetRecoveryStatusInput struct {
|
||||||
RecoverySessionID string
|
RecoverySessionID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRecoveryStatusOutput represents output from getting recovery status
|
// GetRecoveryStatusOutput represents output from getting recovery status
|
||||||
type GetRecoveryStatusOutput struct {
|
type GetRecoveryStatusOutput struct {
|
||||||
RecoverySession *entities.RecoverySession
|
RecoverySession *entities.RecoverySession
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRecoveryStatusUseCase handles getting recovery session status
|
// GetRecoveryStatusUseCase handles getting recovery session status
|
||||||
type GetRecoveryStatusUseCase struct {
|
type GetRecoveryStatusUseCase struct {
|
||||||
recoveryRepo repositories.RecoverySessionRepository
|
recoveryRepo repositories.RecoverySessionRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGetRecoveryStatusUseCase creates a new GetRecoveryStatusUseCase
|
// NewGetRecoveryStatusUseCase creates a new GetRecoveryStatusUseCase
|
||||||
func NewGetRecoveryStatusUseCase(recoveryRepo repositories.RecoverySessionRepository) *GetRecoveryStatusUseCase {
|
func NewGetRecoveryStatusUseCase(recoveryRepo repositories.RecoverySessionRepository) *GetRecoveryStatusUseCase {
|
||||||
return &GetRecoveryStatusUseCase{
|
return &GetRecoveryStatusUseCase{
|
||||||
recoveryRepo: recoveryRepo,
|
recoveryRepo: recoveryRepo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute gets recovery session status
|
// Execute gets recovery session status
|
||||||
func (uc *GetRecoveryStatusUseCase) Execute(ctx context.Context, input GetRecoveryStatusInput) (*GetRecoveryStatusOutput, error) {
|
func (uc *GetRecoveryStatusUseCase) Execute(ctx context.Context, input GetRecoveryStatusInput) (*GetRecoveryStatusOutput, error) {
|
||||||
recovery, err := uc.recoveryRepo.GetByID(ctx, input.RecoverySessionID)
|
recovery, err := uc.recoveryRepo.GetByID(ctx, input.RecoverySessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &GetRecoveryStatusOutput{
|
return &GetRecoveryStatusOutput{
|
||||||
RecoverySession: recovery,
|
RecoverySession: recovery,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CancelRecoveryInput represents input for canceling recovery
|
// CancelRecoveryInput represents input for canceling recovery
|
||||||
type CancelRecoveryInput struct {
|
type CancelRecoveryInput struct {
|
||||||
RecoverySessionID string
|
RecoverySessionID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// CancelRecoveryUseCase handles canceling recovery
|
// CancelRecoveryUseCase handles canceling recovery
|
||||||
type CancelRecoveryUseCase struct {
|
type CancelRecoveryUseCase struct {
|
||||||
accountRepo repositories.AccountRepository
|
accountRepo repositories.AccountRepository
|
||||||
recoveryRepo repositories.RecoverySessionRepository
|
recoveryRepo repositories.RecoverySessionRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCancelRecoveryUseCase creates a new CancelRecoveryUseCase
|
// NewCancelRecoveryUseCase creates a new CancelRecoveryUseCase
|
||||||
func NewCancelRecoveryUseCase(
|
func NewCancelRecoveryUseCase(
|
||||||
accountRepo repositories.AccountRepository,
|
accountRepo repositories.AccountRepository,
|
||||||
recoveryRepo repositories.RecoverySessionRepository,
|
recoveryRepo repositories.RecoverySessionRepository,
|
||||||
) *CancelRecoveryUseCase {
|
) *CancelRecoveryUseCase {
|
||||||
return &CancelRecoveryUseCase{
|
return &CancelRecoveryUseCase{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
recoveryRepo: recoveryRepo,
|
recoveryRepo: recoveryRepo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute cancels a recovery session
|
// Execute cancels a recovery session
|
||||||
func (uc *CancelRecoveryUseCase) Execute(ctx context.Context, input CancelRecoveryInput) error {
|
func (uc *CancelRecoveryUseCase) Execute(ctx context.Context, input CancelRecoveryInput) error {
|
||||||
// Get recovery session
|
// Get recovery session
|
||||||
recovery, err := uc.recoveryRepo.GetByID(ctx, input.RecoverySessionID)
|
recovery, err := uc.recoveryRepo.GetByID(ctx, input.RecoverySessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if recovery can be canceled
|
// Check if recovery can be canceled
|
||||||
if recovery.IsCompleted() {
|
if recovery.IsCompleted() {
|
||||||
return &entities.AccountError{
|
return &entities.AccountError{
|
||||||
Code: "RECOVERY_CANNOT_CANCEL",
|
Code: "RECOVERY_CANNOT_CANCEL",
|
||||||
Message: "cannot cancel completed recovery",
|
Message: "cannot cancel completed recovery",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark recovery as failed
|
// Mark recovery as failed
|
||||||
if err := recovery.Fail(); err != nil {
|
if err := recovery.Fail(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update recovery session
|
// Update recovery session
|
||||||
if err := uc.recoveryRepo.Update(ctx, recovery); err != nil {
|
if err := uc.recoveryRepo.Update(ctx, recovery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reactivate account
|
// Reactivate account
|
||||||
account, err := uc.accountRepo.GetByID(ctx, recovery.AccountID)
|
account, err := uc.accountRepo.GetByID(ctx, recovery.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Activate()
|
account.Activate()
|
||||||
if err := uc.accountRepo.Update(ctx, account); err != nil {
|
if err := uc.accountRepo.Update(ctx, account); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -174,50 +174,191 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) {
|
func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) {
|
||||||
db, err := sql.Open("postgres", cfg.DSN())
|
const maxRetries = 10
|
||||||
if err != nil {
|
const retryDelay = 2 * time.Second
|
||||||
return nil, err
|
|
||||||
|
var db *sql.DB
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for i := 0; i < maxRetries; i++ {
|
||||||
|
db, err = sql.Open("postgres", cfg.DSN())
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to open database connection, retrying...",
|
||||||
|
zap.Int("attempt", i+1),
|
||||||
|
zap.Int("max_retries", maxRetries),
|
||||||
|
zap.Error(err))
|
||||||
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||||
|
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||||
|
db.SetConnMaxLifetime(cfg.ConnMaxLife)
|
||||||
|
|
||||||
|
// Test connection with Ping
|
||||||
|
if err = db.Ping(); err != nil {
|
||||||
|
logger.Warn("Failed to ping database, retrying...",
|
||||||
|
zap.Int("attempt", i+1),
|
||||||
|
zap.Int("max_retries", maxRetries),
|
||||||
|
zap.Error(err))
|
||||||
|
db.Close()
|
||||||
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify database is actually usable with a simple query
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
var result int
|
||||||
|
err = db.QueryRowContext(ctx, "SELECT 1").Scan(&result)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Database ping succeeded but query failed, retrying...",
|
||||||
|
zap.Int("attempt", i+1),
|
||||||
|
zap.Int("max_retries", maxRetries),
|
||||||
|
zap.Error(err))
|
||||||
|
db.Close()
|
||||||
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Connected to PostgreSQL and verified connectivity",
|
||||||
|
zap.Int("attempt", i+1))
|
||||||
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
return nil, fmt.Errorf("failed to connect to database after %d retries: %w", maxRetries, err)
|
||||||
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
|
||||||
db.SetConnMaxLifetime(cfg.ConnMaxLife)
|
|
||||||
|
|
||||||
// Test connection
|
|
||||||
if err := db.Ping(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Connected to PostgreSQL")
|
|
||||||
return db, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func initRedis(cfg config.RedisConfig) *redis.Client {
|
func initRedis(cfg config.RedisConfig) *redis.Client {
|
||||||
|
const maxRetries = 10
|
||||||
|
const retryDelay = 2 * time.Second
|
||||||
|
|
||||||
client := redis.NewClient(&redis.Options{
|
client := redis.NewClient(&redis.Options{
|
||||||
Addr: cfg.Addr(),
|
Addr: cfg.Addr(),
|
||||||
Password: cfg.Password,
|
Password: cfg.Password,
|
||||||
DB: cfg.DB,
|
DB: cfg.DB,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Test connection
|
// Test connection with retry
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
if err := client.Ping(ctx).Err(); err != nil {
|
for i := 0; i < maxRetries; i++ {
|
||||||
logger.Warn("Redis connection failed, continuing without cache", zap.Error(err))
|
if err := client.Ping(ctx).Err(); err != nil {
|
||||||
} else {
|
logger.Warn("Redis connection failed, retrying...",
|
||||||
|
zap.Int("attempt", i+1),
|
||||||
|
zap.Int("max_retries", maxRetries),
|
||||||
|
zap.Error(err))
|
||||||
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
|
continue
|
||||||
|
}
|
||||||
logger.Info("Connected to Redis")
|
logger.Info("Connected to Redis")
|
||||||
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Warn("Redis connection failed after retries, continuing without cache")
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
func initRabbitMQ(cfg config.RabbitMQConfig) (*amqp.Connection, error) {
|
func initRabbitMQ(cfg config.RabbitMQConfig) (*amqp.Connection, error) {
|
||||||
conn, err := amqp.Dial(cfg.URL())
|
const maxRetries = 10
|
||||||
if err != nil {
|
const retryDelay = 2 * time.Second
|
||||||
return nil, err
|
|
||||||
|
var conn *amqp.Connection
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for i := 0; i < maxRetries; i++ {
|
||||||
|
// Attempt to dial RabbitMQ
|
||||||
|
conn, err = amqp.Dial(cfg.URL())
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to dial RabbitMQ, retrying...",
|
||||||
|
zap.Int("attempt", i+1),
|
||||||
|
zap.Int("max_retries", maxRetries),
|
||||||
|
zap.String("url", maskPassword(cfg.URL())),
|
||||||
|
zap.Error(err))
|
||||||
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify connection is actually usable by opening a channel
|
||||||
|
ch, err := conn.Channel()
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("RabbitMQ connection established but channel creation failed, retrying...",
|
||||||
|
zap.Int("attempt", i+1),
|
||||||
|
zap.Int("max_retries", maxRetries),
|
||||||
|
zap.Error(err))
|
||||||
|
conn.Close()
|
||||||
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the channel with a simple operation (declare a test exchange)
|
||||||
|
err = ch.ExchangeDeclare(
|
||||||
|
"mpc.health.check", // name
|
||||||
|
"fanout", // type
|
||||||
|
false, // durable
|
||||||
|
true, // auto-deleted
|
||||||
|
false, // internal
|
||||||
|
false, // no-wait
|
||||||
|
nil, // arguments
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("RabbitMQ channel created but exchange declaration failed, retrying...",
|
||||||
|
zap.Int("attempt", i+1),
|
||||||
|
zap.Int("max_retries", maxRetries),
|
||||||
|
zap.Error(err))
|
||||||
|
ch.Close()
|
||||||
|
conn.Close()
|
||||||
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up test exchange
|
||||||
|
ch.ExchangeDelete("mpc.health.check", false, false)
|
||||||
|
ch.Close()
|
||||||
|
|
||||||
|
// Setup connection close notification
|
||||||
|
closeChan := make(chan *amqp.Error, 1)
|
||||||
|
conn.NotifyClose(closeChan)
|
||||||
|
go func() {
|
||||||
|
err := <-closeChan
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("RabbitMQ connection closed unexpectedly", zap.Error(err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
logger.Info("Connected to RabbitMQ and verified connectivity",
|
||||||
|
zap.Int("attempt", i+1))
|
||||||
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Connected to RabbitMQ")
|
return nil, fmt.Errorf("failed to connect to RabbitMQ after %d retries: %w", maxRetries, err)
|
||||||
return conn, nil
|
}
|
||||||
|
|
||||||
|
// maskPassword masks the password in the RabbitMQ URL for logging
|
||||||
|
func maskPassword(url string) string {
|
||||||
|
// Simple masking: amqp://user:password@host:port -> amqp://user:****@host:port
|
||||||
|
start := 0
|
||||||
|
for i := 0; i < len(url); i++ {
|
||||||
|
if url[i] == ':' && i > 0 && url[i-1] != '/' {
|
||||||
|
start = i + 1
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if start == 0 {
|
||||||
|
return url
|
||||||
|
}
|
||||||
|
|
||||||
|
end := start
|
||||||
|
for i := start; i < len(url); i++ {
|
||||||
|
if url[i] == '@' {
|
||||||
|
end = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if end == start {
|
||||||
|
return url
|
||||||
|
}
|
||||||
|
|
||||||
|
return url[:start] + "****" + url[end:]
|
||||||
}
|
}
|
||||||
|
|
||||||
func startHTTPServer(
|
func startHTTPServer(
|
||||||
|
|
|
||||||
|
|
@ -1,160 +1,160 @@
|
||||||
package entities
|
package entities
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Account represents a user account with MPC-based authentication
|
// Account represents a user account with MPC-based authentication
|
||||||
type Account struct {
|
type Account struct {
|
||||||
ID value_objects.AccountID
|
ID value_objects.AccountID
|
||||||
Username string // Required: auto-generated by identity-service
|
Username string // Required: auto-generated by identity-service
|
||||||
Email *string // Optional: for anonymous accounts
|
Email *string // Optional: for anonymous accounts
|
||||||
Phone *string
|
Phone *string
|
||||||
PublicKey []byte // MPC group public key
|
PublicKey []byte // MPC group public key
|
||||||
KeygenSessionID uuid.UUID
|
KeygenSessionID uuid.UUID
|
||||||
ThresholdN int
|
ThresholdN int
|
||||||
ThresholdT int
|
ThresholdT int
|
||||||
Status value_objects.AccountStatus
|
Status value_objects.AccountStatus
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
LastLoginAt *time.Time
|
LastLoginAt *time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccount creates a new Account
|
// NewAccount creates a new Account
|
||||||
func NewAccount(
|
func NewAccount(
|
||||||
username string,
|
username string,
|
||||||
email string,
|
email string,
|
||||||
publicKey []byte,
|
publicKey []byte,
|
||||||
keygenSessionID uuid.UUID,
|
keygenSessionID uuid.UUID,
|
||||||
thresholdN int,
|
thresholdN int,
|
||||||
thresholdT int,
|
thresholdT int,
|
||||||
) *Account {
|
) *Account {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
var emailPtr *string
|
var emailPtr *string
|
||||||
|
|
||||||
if email != "" {
|
if email != "" {
|
||||||
emailPtr = &email
|
emailPtr = &email
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Account{
|
return &Account{
|
||||||
ID: value_objects.NewAccountID(),
|
ID: value_objects.NewAccountID(),
|
||||||
Username: username,
|
Username: username,
|
||||||
Email: emailPtr,
|
Email: emailPtr,
|
||||||
PublicKey: publicKey,
|
PublicKey: publicKey,
|
||||||
KeygenSessionID: keygenSessionID,
|
KeygenSessionID: keygenSessionID,
|
||||||
ThresholdN: thresholdN,
|
ThresholdN: thresholdN,
|
||||||
ThresholdT: thresholdT,
|
ThresholdT: thresholdT,
|
||||||
Status: value_objects.AccountStatusActive,
|
Status: value_objects.AccountStatusActive,
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetPhone sets the phone number
|
// SetPhone sets the phone number
|
||||||
func (a *Account) SetPhone(phone string) {
|
func (a *Account) SetPhone(phone string) {
|
||||||
a.Phone = &phone
|
a.Phone = &phone
|
||||||
a.UpdatedAt = time.Now().UTC()
|
a.UpdatedAt = time.Now().UTC()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLastLogin updates the last login timestamp
|
// UpdateLastLogin updates the last login timestamp
|
||||||
func (a *Account) UpdateLastLogin() {
|
func (a *Account) UpdateLastLogin() {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
a.LastLoginAt = &now
|
a.LastLoginAt = &now
|
||||||
a.UpdatedAt = now
|
a.UpdatedAt = now
|
||||||
}
|
}
|
||||||
|
|
||||||
// Suspend suspends the account
|
// Suspend suspends the account
|
||||||
func (a *Account) Suspend() error {
|
func (a *Account) Suspend() error {
|
||||||
if a.Status == value_objects.AccountStatusRecovering {
|
if a.Status == value_objects.AccountStatusRecovering {
|
||||||
return ErrAccountInRecovery
|
return ErrAccountInRecovery
|
||||||
}
|
}
|
||||||
a.Status = value_objects.AccountStatusSuspended
|
a.Status = value_objects.AccountStatusSuspended
|
||||||
a.UpdatedAt = time.Now().UTC()
|
a.UpdatedAt = time.Now().UTC()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lock locks the account
|
// Lock locks the account
|
||||||
func (a *Account) Lock() error {
|
func (a *Account) Lock() error {
|
||||||
if a.Status == value_objects.AccountStatusRecovering {
|
if a.Status == value_objects.AccountStatusRecovering {
|
||||||
return ErrAccountInRecovery
|
return ErrAccountInRecovery
|
||||||
}
|
}
|
||||||
a.Status = value_objects.AccountStatusLocked
|
a.Status = value_objects.AccountStatusLocked
|
||||||
a.UpdatedAt = time.Now().UTC()
|
a.UpdatedAt = time.Now().UTC()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Activate activates the account
|
// Activate activates the account
|
||||||
func (a *Account) Activate() {
|
func (a *Account) Activate() {
|
||||||
a.Status = value_objects.AccountStatusActive
|
a.Status = value_objects.AccountStatusActive
|
||||||
a.UpdatedAt = time.Now().UTC()
|
a.UpdatedAt = time.Now().UTC()
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartRecovery marks the account as recovering
|
// StartRecovery marks the account as recovering
|
||||||
func (a *Account) StartRecovery() error {
|
func (a *Account) StartRecovery() error {
|
||||||
if !a.Status.CanInitiateRecovery() {
|
if !a.Status.CanInitiateRecovery() {
|
||||||
return ErrCannotInitiateRecovery
|
return ErrCannotInitiateRecovery
|
||||||
}
|
}
|
||||||
a.Status = value_objects.AccountStatusRecovering
|
a.Status = value_objects.AccountStatusRecovering
|
||||||
a.UpdatedAt = time.Now().UTC()
|
a.UpdatedAt = time.Now().UTC()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CompleteRecovery completes the recovery process with new public key
|
// CompleteRecovery completes the recovery process with new public key
|
||||||
func (a *Account) CompleteRecovery(newPublicKey []byte, newKeygenSessionID uuid.UUID) {
|
func (a *Account) CompleteRecovery(newPublicKey []byte, newKeygenSessionID uuid.UUID) {
|
||||||
a.PublicKey = newPublicKey
|
a.PublicKey = newPublicKey
|
||||||
a.KeygenSessionID = newKeygenSessionID
|
a.KeygenSessionID = newKeygenSessionID
|
||||||
a.Status = value_objects.AccountStatusActive
|
a.Status = value_objects.AccountStatusActive
|
||||||
a.UpdatedAt = time.Now().UTC()
|
a.UpdatedAt = time.Now().UTC()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CanLogin checks if the account can login
|
// CanLogin checks if the account can login
|
||||||
func (a *Account) CanLogin() bool {
|
func (a *Account) CanLogin() bool {
|
||||||
return a.Status.CanLogin()
|
return a.Status.CanLogin()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsActive checks if the account is active
|
// IsActive checks if the account is active
|
||||||
func (a *Account) IsActive() bool {
|
func (a *Account) IsActive() bool {
|
||||||
return a.Status == value_objects.AccountStatusActive
|
return a.Status == value_objects.AccountStatusActive
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate validates the account data
|
// Validate validates the account data
|
||||||
func (a *Account) Validate() error {
|
func (a *Account) Validate() error {
|
||||||
if a.Username == "" {
|
if a.Username == "" {
|
||||||
return ErrInvalidUsername
|
return ErrInvalidUsername
|
||||||
}
|
}
|
||||||
// Email is optional, but if provided must be valid (checked by binding)
|
// Email is optional, but if provided must be valid (checked by binding)
|
||||||
if len(a.PublicKey) == 0 {
|
if len(a.PublicKey) == 0 {
|
||||||
return ErrInvalidPublicKey
|
return ErrInvalidPublicKey
|
||||||
}
|
}
|
||||||
if a.ThresholdT > a.ThresholdN || a.ThresholdT <= 0 {
|
if a.ThresholdT > a.ThresholdN || a.ThresholdT <= 0 {
|
||||||
return ErrInvalidThreshold
|
return ErrInvalidThreshold
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Account errors
|
// Account errors
|
||||||
var (
|
var (
|
||||||
ErrInvalidUsername = &AccountError{Code: "INVALID_USERNAME", Message: "username is required"}
|
ErrInvalidUsername = &AccountError{Code: "INVALID_USERNAME", Message: "username is required"}
|
||||||
ErrInvalidEmail = &AccountError{Code: "INVALID_EMAIL", Message: "email is required"}
|
ErrInvalidEmail = &AccountError{Code: "INVALID_EMAIL", Message: "email is required"}
|
||||||
ErrInvalidPublicKey = &AccountError{Code: "INVALID_PUBLIC_KEY", Message: "public key is required"}
|
ErrInvalidPublicKey = &AccountError{Code: "INVALID_PUBLIC_KEY", Message: "public key is required"}
|
||||||
ErrInvalidThreshold = &AccountError{Code: "INVALID_THRESHOLD", Message: "invalid threshold configuration"}
|
ErrInvalidThreshold = &AccountError{Code: "INVALID_THRESHOLD", Message: "invalid threshold configuration"}
|
||||||
ErrAccountInRecovery = &AccountError{Code: "ACCOUNT_IN_RECOVERY", Message: "account is in recovery mode"}
|
ErrAccountInRecovery = &AccountError{Code: "ACCOUNT_IN_RECOVERY", Message: "account is in recovery mode"}
|
||||||
ErrCannotInitiateRecovery = &AccountError{Code: "CANNOT_INITIATE_RECOVERY", Message: "cannot initiate recovery in current state"}
|
ErrCannotInitiateRecovery = &AccountError{Code: "CANNOT_INITIATE_RECOVERY", Message: "cannot initiate recovery in current state"}
|
||||||
ErrAccountNotActive = &AccountError{Code: "ACCOUNT_NOT_ACTIVE", Message: "account is not active"}
|
ErrAccountNotActive = &AccountError{Code: "ACCOUNT_NOT_ACTIVE", Message: "account is not active"}
|
||||||
ErrAccountNotFound = &AccountError{Code: "ACCOUNT_NOT_FOUND", Message: "account not found"}
|
ErrAccountNotFound = &AccountError{Code: "ACCOUNT_NOT_FOUND", Message: "account not found"}
|
||||||
ErrDuplicateUsername = &AccountError{Code: "DUPLICATE_USERNAME", Message: "username already exists"}
|
ErrDuplicateUsername = &AccountError{Code: "DUPLICATE_USERNAME", Message: "username already exists"}
|
||||||
ErrDuplicateEmail = &AccountError{Code: "DUPLICATE_EMAIL", Message: "email already exists"}
|
ErrDuplicateEmail = &AccountError{Code: "DUPLICATE_EMAIL", Message: "email already exists"}
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccountError represents an account domain error
|
// AccountError represents an account domain error
|
||||||
type AccountError struct {
|
type AccountError struct {
|
||||||
Code string
|
Code string
|
||||||
Message string
|
Message string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *AccountError) Error() string {
|
func (e *AccountError) Error() string {
|
||||||
return e.Message
|
return e.Message
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,104 +1,104 @@
|
||||||
package entities
|
package entities
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccountShare represents a mapping of key share to account
|
// AccountShare represents a mapping of key share to account
|
||||||
// Note: This records share location, not share content
|
// Note: This records share location, not share content
|
||||||
type AccountShare struct {
|
type AccountShare struct {
|
||||||
ID uuid.UUID
|
ID uuid.UUID
|
||||||
AccountID value_objects.AccountID
|
AccountID value_objects.AccountID
|
||||||
ShareType value_objects.ShareType
|
ShareType value_objects.ShareType
|
||||||
PartyID string
|
PartyID string
|
||||||
PartyIndex int
|
PartyIndex int
|
||||||
DeviceType *string
|
DeviceType *string
|
||||||
DeviceID *string
|
DeviceID *string
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
LastUsedAt *time.Time
|
LastUsedAt *time.Time
|
||||||
IsActive bool
|
IsActive bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountShare creates a new AccountShare
|
// NewAccountShare creates a new AccountShare
|
||||||
func NewAccountShare(
|
func NewAccountShare(
|
||||||
accountID value_objects.AccountID,
|
accountID value_objects.AccountID,
|
||||||
shareType value_objects.ShareType,
|
shareType value_objects.ShareType,
|
||||||
partyID string,
|
partyID string,
|
||||||
partyIndex int,
|
partyIndex int,
|
||||||
) *AccountShare {
|
) *AccountShare {
|
||||||
return &AccountShare{
|
return &AccountShare{
|
||||||
ID: uuid.New(),
|
ID: uuid.New(),
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
ShareType: shareType,
|
ShareType: shareType,
|
||||||
PartyID: partyID,
|
PartyID: partyID,
|
||||||
PartyIndex: partyIndex,
|
PartyIndex: partyIndex,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
IsActive: true,
|
IsActive: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetDeviceInfo sets device information for user device shares
|
// SetDeviceInfo sets device information for user device shares
|
||||||
func (s *AccountShare) SetDeviceInfo(deviceType, deviceID string) {
|
func (s *AccountShare) SetDeviceInfo(deviceType, deviceID string) {
|
||||||
s.DeviceType = &deviceType
|
s.DeviceType = &deviceType
|
||||||
s.DeviceID = &deviceID
|
s.DeviceID = &deviceID
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLastUsed updates the last used timestamp
|
// UpdateLastUsed updates the last used timestamp
|
||||||
func (s *AccountShare) UpdateLastUsed() {
|
func (s *AccountShare) UpdateLastUsed() {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
s.LastUsedAt = &now
|
s.LastUsedAt = &now
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deactivate deactivates the share (e.g., when device is lost)
|
// Deactivate deactivates the share (e.g., when device is lost)
|
||||||
func (s *AccountShare) Deactivate() {
|
func (s *AccountShare) Deactivate() {
|
||||||
s.IsActive = false
|
s.IsActive = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Activate activates the share
|
// Activate activates the share
|
||||||
func (s *AccountShare) Activate() {
|
func (s *AccountShare) Activate() {
|
||||||
s.IsActive = true
|
s.IsActive = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsUserDeviceShare checks if this is a user device share
|
// IsUserDeviceShare checks if this is a user device share
|
||||||
func (s *AccountShare) IsUserDeviceShare() bool {
|
func (s *AccountShare) IsUserDeviceShare() bool {
|
||||||
return s.ShareType == value_objects.ShareTypeUserDevice
|
return s.ShareType == value_objects.ShareTypeUserDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsServerShare checks if this is a server share
|
// IsServerShare checks if this is a server share
|
||||||
func (s *AccountShare) IsServerShare() bool {
|
func (s *AccountShare) IsServerShare() bool {
|
||||||
return s.ShareType == value_objects.ShareTypeServer
|
return s.ShareType == value_objects.ShareTypeServer
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsRecoveryShare checks if this is a recovery share
|
// IsRecoveryShare checks if this is a recovery share
|
||||||
func (s *AccountShare) IsRecoveryShare() bool {
|
func (s *AccountShare) IsRecoveryShare() bool {
|
||||||
return s.ShareType == value_objects.ShareTypeRecovery
|
return s.ShareType == value_objects.ShareTypeRecovery
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate validates the account share
|
// Validate validates the account share
|
||||||
func (s *AccountShare) Validate() error {
|
func (s *AccountShare) Validate() error {
|
||||||
if s.AccountID.IsZero() {
|
if s.AccountID.IsZero() {
|
||||||
return ErrShareInvalidAccountID
|
return ErrShareInvalidAccountID
|
||||||
}
|
}
|
||||||
if !s.ShareType.IsValid() {
|
if !s.ShareType.IsValid() {
|
||||||
return ErrShareInvalidType
|
return ErrShareInvalidType
|
||||||
}
|
}
|
||||||
if s.PartyID == "" {
|
if s.PartyID == "" {
|
||||||
return ErrShareInvalidPartyID
|
return ErrShareInvalidPartyID
|
||||||
}
|
}
|
||||||
if s.PartyIndex < 0 {
|
if s.PartyIndex < 0 {
|
||||||
return ErrShareInvalidPartyIndex
|
return ErrShareInvalidPartyIndex
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountShare errors
|
// AccountShare errors
|
||||||
var (
|
var (
|
||||||
ErrShareInvalidAccountID = &AccountError{Code: "SHARE_INVALID_ACCOUNT_ID", Message: "invalid account ID"}
|
ErrShareInvalidAccountID = &AccountError{Code: "SHARE_INVALID_ACCOUNT_ID", Message: "invalid account ID"}
|
||||||
ErrShareInvalidType = &AccountError{Code: "SHARE_INVALID_TYPE", Message: "invalid share type"}
|
ErrShareInvalidType = &AccountError{Code: "SHARE_INVALID_TYPE", Message: "invalid share type"}
|
||||||
ErrShareInvalidPartyID = &AccountError{Code: "SHARE_INVALID_PARTY_ID", Message: "invalid party ID"}
|
ErrShareInvalidPartyID = &AccountError{Code: "SHARE_INVALID_PARTY_ID", Message: "invalid party ID"}
|
||||||
ErrShareInvalidPartyIndex = &AccountError{Code: "SHARE_INVALID_PARTY_INDEX", Message: "invalid party index"}
|
ErrShareInvalidPartyIndex = &AccountError{Code: "SHARE_INVALID_PARTY_INDEX", Message: "invalid party index"}
|
||||||
ErrShareNotFound = &AccountError{Code: "SHARE_NOT_FOUND", Message: "share not found"}
|
ErrShareNotFound = &AccountError{Code: "SHARE_NOT_FOUND", Message: "share not found"}
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,104 +1,104 @@
|
||||||
package entities
|
package entities
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RecoverySession represents an account recovery session
|
// RecoverySession represents an account recovery session
|
||||||
type RecoverySession struct {
|
type RecoverySession struct {
|
||||||
ID uuid.UUID
|
ID uuid.UUID
|
||||||
AccountID value_objects.AccountID
|
AccountID value_objects.AccountID
|
||||||
RecoveryType value_objects.RecoveryType
|
RecoveryType value_objects.RecoveryType
|
||||||
OldShareType *value_objects.ShareType
|
OldShareType *value_objects.ShareType
|
||||||
NewKeygenSessionID *uuid.UUID
|
NewKeygenSessionID *uuid.UUID
|
||||||
Status value_objects.RecoveryStatus
|
Status value_objects.RecoveryStatus
|
||||||
RequestedAt time.Time
|
RequestedAt time.Time
|
||||||
CompletedAt *time.Time
|
CompletedAt *time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRecoverySession creates a new RecoverySession
|
// NewRecoverySession creates a new RecoverySession
|
||||||
func NewRecoverySession(
|
func NewRecoverySession(
|
||||||
accountID value_objects.AccountID,
|
accountID value_objects.AccountID,
|
||||||
recoveryType value_objects.RecoveryType,
|
recoveryType value_objects.RecoveryType,
|
||||||
) *RecoverySession {
|
) *RecoverySession {
|
||||||
return &RecoverySession{
|
return &RecoverySession{
|
||||||
ID: uuid.New(),
|
ID: uuid.New(),
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
RecoveryType: recoveryType,
|
RecoveryType: recoveryType,
|
||||||
Status: value_objects.RecoveryStatusRequested,
|
Status: value_objects.RecoveryStatusRequested,
|
||||||
RequestedAt: time.Now().UTC(),
|
RequestedAt: time.Now().UTC(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetOldShareType sets the old share type being replaced
|
// SetOldShareType sets the old share type being replaced
|
||||||
func (r *RecoverySession) SetOldShareType(shareType value_objects.ShareType) {
|
func (r *RecoverySession) SetOldShareType(shareType value_objects.ShareType) {
|
||||||
r.OldShareType = &shareType
|
r.OldShareType = &shareType
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartKeygen starts the keygen process for recovery
|
// StartKeygen starts the keygen process for recovery
|
||||||
func (r *RecoverySession) StartKeygen(keygenSessionID uuid.UUID) error {
|
func (r *RecoverySession) StartKeygen(keygenSessionID uuid.UUID) error {
|
||||||
if r.Status != value_objects.RecoveryStatusRequested {
|
if r.Status != value_objects.RecoveryStatusRequested {
|
||||||
return ErrRecoveryInvalidState
|
return ErrRecoveryInvalidState
|
||||||
}
|
}
|
||||||
r.NewKeygenSessionID = &keygenSessionID
|
r.NewKeygenSessionID = &keygenSessionID
|
||||||
r.Status = value_objects.RecoveryStatusInProgress
|
r.Status = value_objects.RecoveryStatusInProgress
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Complete marks the recovery as completed
|
// Complete marks the recovery as completed
|
||||||
func (r *RecoverySession) Complete() error {
|
func (r *RecoverySession) Complete() error {
|
||||||
if r.Status != value_objects.RecoveryStatusInProgress {
|
if r.Status != value_objects.RecoveryStatusInProgress {
|
||||||
return ErrRecoveryInvalidState
|
return ErrRecoveryInvalidState
|
||||||
}
|
}
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
r.CompletedAt = &now
|
r.CompletedAt = &now
|
||||||
r.Status = value_objects.RecoveryStatusCompleted
|
r.Status = value_objects.RecoveryStatusCompleted
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fail marks the recovery as failed
|
// Fail marks the recovery as failed
|
||||||
func (r *RecoverySession) Fail() error {
|
func (r *RecoverySession) Fail() error {
|
||||||
if r.Status == value_objects.RecoveryStatusCompleted {
|
if r.Status == value_objects.RecoveryStatusCompleted {
|
||||||
return ErrRecoveryAlreadyCompleted
|
return ErrRecoveryAlreadyCompleted
|
||||||
}
|
}
|
||||||
r.Status = value_objects.RecoveryStatusFailed
|
r.Status = value_objects.RecoveryStatusFailed
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsCompleted checks if recovery is completed
|
// IsCompleted checks if recovery is completed
|
||||||
func (r *RecoverySession) IsCompleted() bool {
|
func (r *RecoverySession) IsCompleted() bool {
|
||||||
return r.Status == value_objects.RecoveryStatusCompleted
|
return r.Status == value_objects.RecoveryStatusCompleted
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsFailed checks if recovery failed
|
// IsFailed checks if recovery failed
|
||||||
func (r *RecoverySession) IsFailed() bool {
|
func (r *RecoverySession) IsFailed() bool {
|
||||||
return r.Status == value_objects.RecoveryStatusFailed
|
return r.Status == value_objects.RecoveryStatusFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsInProgress checks if recovery is in progress
|
// IsInProgress checks if recovery is in progress
|
||||||
func (r *RecoverySession) IsInProgress() bool {
|
func (r *RecoverySession) IsInProgress() bool {
|
||||||
return r.Status == value_objects.RecoveryStatusInProgress
|
return r.Status == value_objects.RecoveryStatusInProgress
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate validates the recovery session
|
// Validate validates the recovery session
|
||||||
func (r *RecoverySession) Validate() error {
|
func (r *RecoverySession) Validate() error {
|
||||||
if r.AccountID.IsZero() {
|
if r.AccountID.IsZero() {
|
||||||
return ErrRecoveryInvalidAccountID
|
return ErrRecoveryInvalidAccountID
|
||||||
}
|
}
|
||||||
if !r.RecoveryType.IsValid() {
|
if !r.RecoveryType.IsValid() {
|
||||||
return ErrRecoveryInvalidType
|
return ErrRecoveryInvalidType
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recovery errors
|
// Recovery errors
|
||||||
var (
|
var (
|
||||||
ErrRecoveryInvalidAccountID = &AccountError{Code: "RECOVERY_INVALID_ACCOUNT_ID", Message: "invalid account ID for recovery"}
|
ErrRecoveryInvalidAccountID = &AccountError{Code: "RECOVERY_INVALID_ACCOUNT_ID", Message: "invalid account ID for recovery"}
|
||||||
ErrRecoveryInvalidType = &AccountError{Code: "RECOVERY_INVALID_TYPE", Message: "invalid recovery type"}
|
ErrRecoveryInvalidType = &AccountError{Code: "RECOVERY_INVALID_TYPE", Message: "invalid recovery type"}
|
||||||
ErrRecoveryInvalidState = &AccountError{Code: "RECOVERY_INVALID_STATE", Message: "invalid recovery state for this operation"}
|
ErrRecoveryInvalidState = &AccountError{Code: "RECOVERY_INVALID_STATE", Message: "invalid recovery state for this operation"}
|
||||||
ErrRecoveryAlreadyCompleted = &AccountError{Code: "RECOVERY_ALREADY_COMPLETED", Message: "recovery already completed"}
|
ErrRecoveryAlreadyCompleted = &AccountError{Code: "RECOVERY_ALREADY_COMPLETED", Message: "recovery already completed"}
|
||||||
ErrRecoveryNotFound = &AccountError{Code: "RECOVERY_NOT_FOUND", Message: "recovery session not found"}
|
ErrRecoveryNotFound = &AccountError{Code: "RECOVERY_NOT_FOUND", Message: "recovery session not found"}
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,95 +1,95 @@
|
||||||
package repositories
|
package repositories
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccountRepository defines the interface for account persistence
|
// AccountRepository defines the interface for account persistence
|
||||||
type AccountRepository interface {
|
type AccountRepository interface {
|
||||||
// Create creates a new account
|
// Create creates a new account
|
||||||
Create(ctx context.Context, account *entities.Account) error
|
Create(ctx context.Context, account *entities.Account) error
|
||||||
|
|
||||||
// GetByID retrieves an account by ID
|
// GetByID retrieves an account by ID
|
||||||
GetByID(ctx context.Context, id value_objects.AccountID) (*entities.Account, error)
|
GetByID(ctx context.Context, id value_objects.AccountID) (*entities.Account, error)
|
||||||
|
|
||||||
// GetByUsername retrieves an account by username
|
// GetByUsername retrieves an account by username
|
||||||
GetByUsername(ctx context.Context, username string) (*entities.Account, error)
|
GetByUsername(ctx context.Context, username string) (*entities.Account, error)
|
||||||
|
|
||||||
// GetByEmail retrieves an account by email
|
// GetByEmail retrieves an account by email
|
||||||
GetByEmail(ctx context.Context, email string) (*entities.Account, error)
|
GetByEmail(ctx context.Context, email string) (*entities.Account, error)
|
||||||
|
|
||||||
// GetByPublicKey retrieves an account by public key
|
// GetByPublicKey retrieves an account by public key
|
||||||
GetByPublicKey(ctx context.Context, publicKey []byte) (*entities.Account, error)
|
GetByPublicKey(ctx context.Context, publicKey []byte) (*entities.Account, error)
|
||||||
|
|
||||||
// Update updates an existing account
|
// Update updates an existing account
|
||||||
Update(ctx context.Context, account *entities.Account) error
|
Update(ctx context.Context, account *entities.Account) error
|
||||||
|
|
||||||
// Delete deletes an account
|
// Delete deletes an account
|
||||||
Delete(ctx context.Context, id value_objects.AccountID) error
|
Delete(ctx context.Context, id value_objects.AccountID) error
|
||||||
|
|
||||||
// ExistsByUsername checks if username exists
|
// ExistsByUsername checks if username exists
|
||||||
ExistsByUsername(ctx context.Context, username string) (bool, error)
|
ExistsByUsername(ctx context.Context, username string) (bool, error)
|
||||||
|
|
||||||
// ExistsByEmail checks if email exists
|
// ExistsByEmail checks if email exists
|
||||||
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
||||||
|
|
||||||
// List lists accounts with pagination
|
// List lists accounts with pagination
|
||||||
List(ctx context.Context, offset, limit int) ([]*entities.Account, error)
|
List(ctx context.Context, offset, limit int) ([]*entities.Account, error)
|
||||||
|
|
||||||
// Count returns the total number of accounts
|
// Count returns the total number of accounts
|
||||||
Count(ctx context.Context) (int64, error)
|
Count(ctx context.Context) (int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountShareRepository defines the interface for account share persistence
|
// AccountShareRepository defines the interface for account share persistence
|
||||||
type AccountShareRepository interface {
|
type AccountShareRepository interface {
|
||||||
// Create creates a new account share
|
// Create creates a new account share
|
||||||
Create(ctx context.Context, share *entities.AccountShare) error
|
Create(ctx context.Context, share *entities.AccountShare) error
|
||||||
|
|
||||||
// GetByID retrieves a share by ID
|
// GetByID retrieves a share by ID
|
||||||
GetByID(ctx context.Context, id string) (*entities.AccountShare, error)
|
GetByID(ctx context.Context, id string) (*entities.AccountShare, error)
|
||||||
|
|
||||||
// GetByAccountID retrieves all shares for an account
|
// GetByAccountID retrieves all shares for an account
|
||||||
GetByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.AccountShare, error)
|
GetByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.AccountShare, error)
|
||||||
|
|
||||||
// GetActiveByAccountID retrieves active shares for an account
|
// GetActiveByAccountID retrieves active shares for an account
|
||||||
GetActiveByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.AccountShare, error)
|
GetActiveByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.AccountShare, error)
|
||||||
|
|
||||||
// GetByPartyID retrieves shares by party ID
|
// GetByPartyID retrieves shares by party ID
|
||||||
GetByPartyID(ctx context.Context, partyID string) ([]*entities.AccountShare, error)
|
GetByPartyID(ctx context.Context, partyID string) ([]*entities.AccountShare, error)
|
||||||
|
|
||||||
// Update updates a share
|
// Update updates a share
|
||||||
Update(ctx context.Context, share *entities.AccountShare) error
|
Update(ctx context.Context, share *entities.AccountShare) error
|
||||||
|
|
||||||
// Delete deletes a share
|
// Delete deletes a share
|
||||||
Delete(ctx context.Context, id string) error
|
Delete(ctx context.Context, id string) error
|
||||||
|
|
||||||
// DeactivateByAccountID deactivates all shares for an account
|
// DeactivateByAccountID deactivates all shares for an account
|
||||||
DeactivateByAccountID(ctx context.Context, accountID value_objects.AccountID) error
|
DeactivateByAccountID(ctx context.Context, accountID value_objects.AccountID) error
|
||||||
|
|
||||||
// DeactivateByShareType deactivates shares of a specific type for an account
|
// DeactivateByShareType deactivates shares of a specific type for an account
|
||||||
DeactivateByShareType(ctx context.Context, accountID value_objects.AccountID, shareType value_objects.ShareType) error
|
DeactivateByShareType(ctx context.Context, accountID value_objects.AccountID, shareType value_objects.ShareType) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecoverySessionRepository defines the interface for recovery session persistence
|
// RecoverySessionRepository defines the interface for recovery session persistence
|
||||||
type RecoverySessionRepository interface {
|
type RecoverySessionRepository interface {
|
||||||
// Create creates a new recovery session
|
// Create creates a new recovery session
|
||||||
Create(ctx context.Context, session *entities.RecoverySession) error
|
Create(ctx context.Context, session *entities.RecoverySession) error
|
||||||
|
|
||||||
// GetByID retrieves a recovery session by ID
|
// GetByID retrieves a recovery session by ID
|
||||||
GetByID(ctx context.Context, id string) (*entities.RecoverySession, error)
|
GetByID(ctx context.Context, id string) (*entities.RecoverySession, error)
|
||||||
|
|
||||||
// GetByAccountID retrieves recovery sessions for an account
|
// GetByAccountID retrieves recovery sessions for an account
|
||||||
GetByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.RecoverySession, error)
|
GetByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.RecoverySession, error)
|
||||||
|
|
||||||
// GetActiveByAccountID retrieves active recovery sessions for an account
|
// GetActiveByAccountID retrieves active recovery sessions for an account
|
||||||
GetActiveByAccountID(ctx context.Context, accountID value_objects.AccountID) (*entities.RecoverySession, error)
|
GetActiveByAccountID(ctx context.Context, accountID value_objects.AccountID) (*entities.RecoverySession, error)
|
||||||
|
|
||||||
// Update updates a recovery session
|
// Update updates a recovery session
|
||||||
Update(ctx context.Context, session *entities.RecoverySession) error
|
Update(ctx context.Context, session *entities.RecoverySession) error
|
||||||
|
|
||||||
// Delete deletes a recovery session
|
// Delete deletes a recovery session
|
||||||
Delete(ctx context.Context, id string) error
|
Delete(ctx context.Context, id string) error
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,272 +1,272 @@
|
||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/pkg/crypto"
|
"github.com/rwadurian/mpc-system/pkg/crypto"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
||||||
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccountDomainService provides domain logic for accounts
|
// AccountDomainService provides domain logic for accounts
|
||||||
type AccountDomainService struct {
|
type AccountDomainService struct {
|
||||||
accountRepo repositories.AccountRepository
|
accountRepo repositories.AccountRepository
|
||||||
shareRepo repositories.AccountShareRepository
|
shareRepo repositories.AccountShareRepository
|
||||||
recoveryRepo repositories.RecoverySessionRepository
|
recoveryRepo repositories.RecoverySessionRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountDomainService creates a new AccountDomainService
|
// NewAccountDomainService creates a new AccountDomainService
|
||||||
func NewAccountDomainService(
|
func NewAccountDomainService(
|
||||||
accountRepo repositories.AccountRepository,
|
accountRepo repositories.AccountRepository,
|
||||||
shareRepo repositories.AccountShareRepository,
|
shareRepo repositories.AccountShareRepository,
|
||||||
recoveryRepo repositories.RecoverySessionRepository,
|
recoveryRepo repositories.RecoverySessionRepository,
|
||||||
) *AccountDomainService {
|
) *AccountDomainService {
|
||||||
return &AccountDomainService{
|
return &AccountDomainService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
shareRepo: shareRepo,
|
shareRepo: shareRepo,
|
||||||
recoveryRepo: recoveryRepo,
|
recoveryRepo: recoveryRepo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccountInput represents input for creating an account
|
// CreateAccountInput represents input for creating an account
|
||||||
type CreateAccountInput struct {
|
type CreateAccountInput struct {
|
||||||
Username string
|
Username string
|
||||||
Email string
|
Email string
|
||||||
Phone *string
|
Phone *string
|
||||||
PublicKey []byte
|
PublicKey []byte
|
||||||
KeygenSessionID uuid.UUID
|
KeygenSessionID uuid.UUID
|
||||||
ThresholdN int
|
ThresholdN int
|
||||||
ThresholdT int
|
ThresholdT int
|
||||||
Shares []ShareInfo
|
Shares []ShareInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShareInfo represents information about a key share
|
// ShareInfo represents information about a key share
|
||||||
type ShareInfo struct {
|
type ShareInfo struct {
|
||||||
ShareType value_objects.ShareType
|
ShareType value_objects.ShareType
|
||||||
PartyID string
|
PartyID string
|
||||||
PartyIndex int
|
PartyIndex int
|
||||||
DeviceType *string
|
DeviceType *string
|
||||||
DeviceID *string
|
DeviceID *string
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccount creates a new account with shares
|
// CreateAccount creates a new account with shares
|
||||||
func (s *AccountDomainService) CreateAccount(ctx context.Context, input CreateAccountInput) (*entities.Account, error) {
|
func (s *AccountDomainService) CreateAccount(ctx context.Context, input CreateAccountInput) (*entities.Account, error) {
|
||||||
// Check username uniqueness
|
// Check username uniqueness
|
||||||
exists, err := s.accountRepo.ExistsByUsername(ctx, input.Username)
|
exists, err := s.accountRepo.ExistsByUsername(ctx, input.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if exists {
|
if exists {
|
||||||
return nil, entities.ErrDuplicateUsername
|
return nil, entities.ErrDuplicateUsername
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check email uniqueness
|
// Check email uniqueness
|
||||||
exists, err = s.accountRepo.ExistsByEmail(ctx, input.Email)
|
exists, err = s.accountRepo.ExistsByEmail(ctx, input.Email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if exists {
|
if exists {
|
||||||
return nil, entities.ErrDuplicateEmail
|
return nil, entities.ErrDuplicateEmail
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create account
|
// Create account
|
||||||
account := entities.NewAccount(
|
account := entities.NewAccount(
|
||||||
input.Username,
|
input.Username,
|
||||||
input.Email,
|
input.Email,
|
||||||
input.PublicKey,
|
input.PublicKey,
|
||||||
input.KeygenSessionID,
|
input.KeygenSessionID,
|
||||||
input.ThresholdN,
|
input.ThresholdN,
|
||||||
input.ThresholdT,
|
input.ThresholdT,
|
||||||
)
|
)
|
||||||
|
|
||||||
if input.Phone != nil {
|
if input.Phone != nil {
|
||||||
account.SetPhone(*input.Phone)
|
account.SetPhone(*input.Phone)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate account
|
// Validate account
|
||||||
if err := account.Validate(); err != nil {
|
if err := account.Validate(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create account in repository
|
// Create account in repository
|
||||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create shares
|
// Create shares
|
||||||
for _, shareInfo := range input.Shares {
|
for _, shareInfo := range input.Shares {
|
||||||
share := entities.NewAccountShare(
|
share := entities.NewAccountShare(
|
||||||
account.ID,
|
account.ID,
|
||||||
shareInfo.ShareType,
|
shareInfo.ShareType,
|
||||||
shareInfo.PartyID,
|
shareInfo.PartyID,
|
||||||
shareInfo.PartyIndex,
|
shareInfo.PartyIndex,
|
||||||
)
|
)
|
||||||
|
|
||||||
if shareInfo.DeviceType != nil && shareInfo.DeviceID != nil {
|
if shareInfo.DeviceType != nil && shareInfo.DeviceID != nil {
|
||||||
share.SetDeviceInfo(*shareInfo.DeviceType, *shareInfo.DeviceID)
|
share.SetDeviceInfo(*shareInfo.DeviceType, *shareInfo.DeviceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := share.Validate(); err != nil {
|
if err := share.Validate(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.shareRepo.Create(ctx, share); err != nil {
|
if err := s.shareRepo.Create(ctx, share); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifySignature verifies a signature against an account's public key
|
// VerifySignature verifies a signature against an account's public key
|
||||||
func (s *AccountDomainService) VerifySignature(ctx context.Context, accountID value_objects.AccountID, message, signature []byte) (bool, error) {
|
func (s *AccountDomainService) VerifySignature(ctx context.Context, accountID value_objects.AccountID, message, signature []byte) (bool, error) {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse public key
|
// Parse public key
|
||||||
pubKey, err := crypto.ParsePublicKey(account.PublicKey)
|
pubKey, err := crypto.ParsePublicKey(account.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify signature
|
// Verify signature
|
||||||
valid := crypto.VerifySignature(pubKey, message, signature)
|
valid := crypto.VerifySignature(pubKey, message, signature)
|
||||||
return valid, nil
|
return valid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitiateRecovery initiates account recovery
|
// InitiateRecovery initiates account recovery
|
||||||
func (s *AccountDomainService) InitiateRecovery(ctx context.Context, accountID value_objects.AccountID, recoveryType value_objects.RecoveryType, oldShareType *value_objects.ShareType) (*entities.RecoverySession, error) {
|
func (s *AccountDomainService) InitiateRecovery(ctx context.Context, accountID value_objects.AccountID, recoveryType value_objects.RecoveryType, oldShareType *value_objects.ShareType) (*entities.RecoverySession, error) {
|
||||||
// Get account
|
// Get account
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if recovery can be initiated
|
// Check if recovery can be initiated
|
||||||
if err := account.StartRecovery(); err != nil {
|
if err := account.StartRecovery(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update account status
|
// Update account status
|
||||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create recovery session
|
// Create recovery session
|
||||||
recoverySession := entities.NewRecoverySession(accountID, recoveryType)
|
recoverySession := entities.NewRecoverySession(accountID, recoveryType)
|
||||||
if oldShareType != nil {
|
if oldShareType != nil {
|
||||||
recoverySession.SetOldShareType(*oldShareType)
|
recoverySession.SetOldShareType(*oldShareType)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := recoverySession.Validate(); err != nil {
|
if err := recoverySession.Validate(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.recoveryRepo.Create(ctx, recoverySession); err != nil {
|
if err := s.recoveryRepo.Create(ctx, recoverySession); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return recoverySession, nil
|
return recoverySession, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CompleteRecovery completes the recovery process
|
// CompleteRecovery completes the recovery process
|
||||||
func (s *AccountDomainService) CompleteRecovery(ctx context.Context, recoverySessionID string, newPublicKey []byte, newKeygenSessionID uuid.UUID, newShares []ShareInfo) error {
|
func (s *AccountDomainService) CompleteRecovery(ctx context.Context, recoverySessionID string, newPublicKey []byte, newKeygenSessionID uuid.UUID, newShares []ShareInfo) error {
|
||||||
// Get recovery session
|
// Get recovery session
|
||||||
recovery, err := s.recoveryRepo.GetByID(ctx, recoverySessionID)
|
recovery, err := s.recoveryRepo.GetByID(ctx, recoverySessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start keygen if still in requested state (transitions to in_progress)
|
// Start keygen if still in requested state (transitions to in_progress)
|
||||||
if recovery.Status == value_objects.RecoveryStatusRequested {
|
if recovery.Status == value_objects.RecoveryStatusRequested {
|
||||||
if err := recovery.StartKeygen(newKeygenSessionID); err != nil {
|
if err := recovery.StartKeygen(newKeygenSessionID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Complete recovery session
|
// Complete recovery session
|
||||||
if err := recovery.Complete(); err != nil {
|
if err := recovery.Complete(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get account
|
// Get account
|
||||||
account, err := s.accountRepo.GetByID(ctx, recovery.AccountID)
|
account, err := s.accountRepo.GetByID(ctx, recovery.AccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Complete account recovery
|
// Complete account recovery
|
||||||
account.CompleteRecovery(newPublicKey, newKeygenSessionID)
|
account.CompleteRecovery(newPublicKey, newKeygenSessionID)
|
||||||
|
|
||||||
// Deactivate old shares
|
// Deactivate old shares
|
||||||
if err := s.shareRepo.DeactivateByAccountID(ctx, account.ID); err != nil {
|
if err := s.shareRepo.DeactivateByAccountID(ctx, account.ID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new shares
|
// Create new shares
|
||||||
for _, shareInfo := range newShares {
|
for _, shareInfo := range newShares {
|
||||||
share := entities.NewAccountShare(
|
share := entities.NewAccountShare(
|
||||||
account.ID,
|
account.ID,
|
||||||
shareInfo.ShareType,
|
shareInfo.ShareType,
|
||||||
shareInfo.PartyID,
|
shareInfo.PartyID,
|
||||||
shareInfo.PartyIndex,
|
shareInfo.PartyIndex,
|
||||||
)
|
)
|
||||||
|
|
||||||
if shareInfo.DeviceType != nil && shareInfo.DeviceID != nil {
|
if shareInfo.DeviceType != nil && shareInfo.DeviceID != nil {
|
||||||
share.SetDeviceInfo(*shareInfo.DeviceType, *shareInfo.DeviceID)
|
share.SetDeviceInfo(*shareInfo.DeviceType, *shareInfo.DeviceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.shareRepo.Create(ctx, share); err != nil {
|
if err := s.shareRepo.Create(ctx, share); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update account
|
// Update account
|
||||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update recovery session
|
// Update recovery session
|
||||||
if err := s.recoveryRepo.Update(ctx, recovery); err != nil {
|
if err := s.recoveryRepo.Update(ctx, recovery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveShares returns active shares for an account
|
// GetActiveShares returns active shares for an account
|
||||||
func (s *AccountDomainService) GetActiveShares(ctx context.Context, accountID value_objects.AccountID) ([]*entities.AccountShare, error) {
|
func (s *AccountDomainService) GetActiveShares(ctx context.Context, accountID value_objects.AccountID) ([]*entities.AccountShare, error) {
|
||||||
return s.shareRepo.GetActiveByAccountID(ctx, accountID)
|
return s.shareRepo.GetActiveByAccountID(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CanAccountSign checks if an account has enough active shares to sign
|
// CanAccountSign checks if an account has enough active shares to sign
|
||||||
func (s *AccountDomainService) CanAccountSign(ctx context.Context, accountID value_objects.AccountID) (bool, error) {
|
func (s *AccountDomainService) CanAccountSign(ctx context.Context, accountID value_objects.AccountID) (bool, error) {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !account.CanLogin() {
|
if !account.CanLogin() {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
shares, err := s.shareRepo.GetActiveByAccountID(ctx, accountID)
|
shares, err := s.shareRepo.GetActiveByAccountID(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count active shares
|
// Count active shares
|
||||||
activeCount := 0
|
activeCount := 0
|
||||||
for _, share := range shares {
|
for _, share := range shares {
|
||||||
if share.IsActive {
|
if share.IsActive {
|
||||||
activeCount++
|
activeCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if we have enough shares for threshold
|
// Check if we have enough shares for threshold
|
||||||
return activeCount >= account.ThresholdT, nil
|
return activeCount >= account.ThresholdT, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,70 +1,70 @@
|
||||||
package value_objects
|
package value_objects
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccountID represents a unique account identifier
|
// AccountID represents a unique account identifier
|
||||||
type AccountID struct {
|
type AccountID struct {
|
||||||
value uuid.UUID
|
value uuid.UUID
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountID creates a new AccountID
|
// NewAccountID creates a new AccountID
|
||||||
func NewAccountID() AccountID {
|
func NewAccountID() AccountID {
|
||||||
return AccountID{value: uuid.New()}
|
return AccountID{value: uuid.New()}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountIDFromString creates an AccountID from a string
|
// AccountIDFromString creates an AccountID from a string
|
||||||
func AccountIDFromString(s string) (AccountID, error) {
|
func AccountIDFromString(s string) (AccountID, error) {
|
||||||
id, err := uuid.Parse(s)
|
id, err := uuid.Parse(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return AccountID{}, err
|
return AccountID{}, err
|
||||||
}
|
}
|
||||||
return AccountID{value: id}, nil
|
return AccountID{value: id}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountIDFromUUID creates an AccountID from a UUID
|
// AccountIDFromUUID creates an AccountID from a UUID
|
||||||
func AccountIDFromUUID(id uuid.UUID) AccountID {
|
func AccountIDFromUUID(id uuid.UUID) AccountID {
|
||||||
return AccountID{value: id}
|
return AccountID{value: id}
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns the string representation
|
// String returns the string representation
|
||||||
func (id AccountID) String() string {
|
func (id AccountID) String() string {
|
||||||
return id.value.String()
|
return id.value.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UUID returns the UUID value
|
// UUID returns the UUID value
|
||||||
func (id AccountID) UUID() uuid.UUID {
|
func (id AccountID) UUID() uuid.UUID {
|
||||||
return id.value
|
return id.value
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsZero checks if the AccountID is zero
|
// IsZero checks if the AccountID is zero
|
||||||
func (id AccountID) IsZero() bool {
|
func (id AccountID) IsZero() bool {
|
||||||
return id.value == uuid.Nil
|
return id.value == uuid.Nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Equals checks if two AccountIDs are equal
|
// Equals checks if two AccountIDs are equal
|
||||||
func (id AccountID) Equals(other AccountID) bool {
|
func (id AccountID) Equals(other AccountID) bool {
|
||||||
return id.value == other.value
|
return id.value == other.value
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON implements json.Marshaler interface
|
// MarshalJSON implements json.Marshaler interface
|
||||||
func (id AccountID) MarshalJSON() ([]byte, error) {
|
func (id AccountID) MarshalJSON() ([]byte, error) {
|
||||||
return []byte(`"` + id.value.String() + `"`), nil
|
return []byte(`"` + id.value.String() + `"`), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalJSON implements json.Unmarshaler interface
|
// UnmarshalJSON implements json.Unmarshaler interface
|
||||||
func (id *AccountID) UnmarshalJSON(data []byte) error {
|
func (id *AccountID) UnmarshalJSON(data []byte) error {
|
||||||
// Remove quotes
|
// Remove quotes
|
||||||
str := string(data)
|
str := string(data)
|
||||||
if len(str) >= 2 && str[0] == '"' && str[len(str)-1] == '"' {
|
if len(str) >= 2 && str[0] == '"' && str[len(str)-1] == '"' {
|
||||||
str = str[1 : len(str)-1]
|
str = str[1 : len(str)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
parsed, err := uuid.Parse(str)
|
parsed, err := uuid.Parse(str)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
id.value = parsed
|
id.value = parsed
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,108 +1,108 @@
|
||||||
package value_objects
|
package value_objects
|
||||||
|
|
||||||
// AccountStatus represents the status of an account
|
// AccountStatus represents the status of an account
|
||||||
type AccountStatus string
|
type AccountStatus string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
AccountStatusActive AccountStatus = "active"
|
AccountStatusActive AccountStatus = "active"
|
||||||
AccountStatusSuspended AccountStatus = "suspended"
|
AccountStatusSuspended AccountStatus = "suspended"
|
||||||
AccountStatusLocked AccountStatus = "locked"
|
AccountStatusLocked AccountStatus = "locked"
|
||||||
AccountStatusRecovering AccountStatus = "recovering"
|
AccountStatusRecovering AccountStatus = "recovering"
|
||||||
)
|
)
|
||||||
|
|
||||||
// String returns the string representation
|
// String returns the string representation
|
||||||
func (s AccountStatus) String() string {
|
func (s AccountStatus) String() string {
|
||||||
return string(s)
|
return string(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValid checks if the status is valid
|
// IsValid checks if the status is valid
|
||||||
func (s AccountStatus) IsValid() bool {
|
func (s AccountStatus) IsValid() bool {
|
||||||
switch s {
|
switch s {
|
||||||
case AccountStatusActive, AccountStatusSuspended, AccountStatusLocked, AccountStatusRecovering:
|
case AccountStatusActive, AccountStatusSuspended, AccountStatusLocked, AccountStatusRecovering:
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CanLogin checks if the account can login with this status
|
// CanLogin checks if the account can login with this status
|
||||||
func (s AccountStatus) CanLogin() bool {
|
func (s AccountStatus) CanLogin() bool {
|
||||||
return s == AccountStatusActive
|
return s == AccountStatusActive
|
||||||
}
|
}
|
||||||
|
|
||||||
// CanInitiateRecovery checks if recovery can be initiated
|
// CanInitiateRecovery checks if recovery can be initiated
|
||||||
func (s AccountStatus) CanInitiateRecovery() bool {
|
func (s AccountStatus) CanInitiateRecovery() bool {
|
||||||
return s == AccountStatusActive || s == AccountStatusLocked
|
return s == AccountStatusActive || s == AccountStatusLocked
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShareType represents the type of key share
|
// ShareType represents the type of key share
|
||||||
type ShareType string
|
type ShareType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ShareTypeUserDevice ShareType = "user_device"
|
ShareTypeUserDevice ShareType = "user_device"
|
||||||
ShareTypeServer ShareType = "server"
|
ShareTypeServer ShareType = "server"
|
||||||
ShareTypeRecovery ShareType = "recovery"
|
ShareTypeRecovery ShareType = "recovery"
|
||||||
)
|
)
|
||||||
|
|
||||||
// String returns the string representation
|
// String returns the string representation
|
||||||
func (st ShareType) String() string {
|
func (st ShareType) String() string {
|
||||||
return string(st)
|
return string(st)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValid checks if the share type is valid
|
// IsValid checks if the share type is valid
|
||||||
func (st ShareType) IsValid() bool {
|
func (st ShareType) IsValid() bool {
|
||||||
switch st {
|
switch st {
|
||||||
case ShareTypeUserDevice, ShareTypeServer, ShareTypeRecovery:
|
case ShareTypeUserDevice, ShareTypeServer, ShareTypeRecovery:
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecoveryType represents the type of account recovery
|
// RecoveryType represents the type of account recovery
|
||||||
type RecoveryType string
|
type RecoveryType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RecoveryTypeDeviceLost RecoveryType = "device_lost"
|
RecoveryTypeDeviceLost RecoveryType = "device_lost"
|
||||||
RecoveryTypeShareRotation RecoveryType = "share_rotation"
|
RecoveryTypeShareRotation RecoveryType = "share_rotation"
|
||||||
)
|
)
|
||||||
|
|
||||||
// String returns the string representation
|
// String returns the string representation
|
||||||
func (rt RecoveryType) String() string {
|
func (rt RecoveryType) String() string {
|
||||||
return string(rt)
|
return string(rt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValid checks if the recovery type is valid
|
// IsValid checks if the recovery type is valid
|
||||||
func (rt RecoveryType) IsValid() bool {
|
func (rt RecoveryType) IsValid() bool {
|
||||||
switch rt {
|
switch rt {
|
||||||
case RecoveryTypeDeviceLost, RecoveryTypeShareRotation:
|
case RecoveryTypeDeviceLost, RecoveryTypeShareRotation:
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecoveryStatus represents the status of a recovery session
|
// RecoveryStatus represents the status of a recovery session
|
||||||
type RecoveryStatus string
|
type RecoveryStatus string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RecoveryStatusRequested RecoveryStatus = "requested"
|
RecoveryStatusRequested RecoveryStatus = "requested"
|
||||||
RecoveryStatusInProgress RecoveryStatus = "in_progress"
|
RecoveryStatusInProgress RecoveryStatus = "in_progress"
|
||||||
RecoveryStatusCompleted RecoveryStatus = "completed"
|
RecoveryStatusCompleted RecoveryStatus = "completed"
|
||||||
RecoveryStatusFailed RecoveryStatus = "failed"
|
RecoveryStatusFailed RecoveryStatus = "failed"
|
||||||
)
|
)
|
||||||
|
|
||||||
// String returns the string representation
|
// String returns the string representation
|
||||||
func (rs RecoveryStatus) String() string {
|
func (rs RecoveryStatus) String() string {
|
||||||
return string(rs)
|
return string(rs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValid checks if the recovery status is valid
|
// IsValid checks if the recovery status is valid
|
||||||
func (rs RecoveryStatus) IsValid() bool {
|
func (rs RecoveryStatus) IsValid() bool {
|
||||||
switch rs {
|
switch rs {
|
||||||
case RecoveryStatusRequested, RecoveryStatusInProgress, RecoveryStatusCompleted, RecoveryStatusFailed:
|
case RecoveryStatusRequested, RecoveryStatusInProgress, RecoveryStatusCompleted, RecoveryStatusFailed:
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,38 +1,38 @@
|
||||||
# Build stage
|
# Build stage
|
||||||
FROM golang:1.21-alpine AS builder
|
FROM golang:1.21-alpine AS builder
|
||||||
|
|
||||||
RUN apk add --no-cache git ca-certificates
|
RUN apk add --no-cache git ca-certificates
|
||||||
|
|
||||||
# Set Go proxy (can be overridden with --build-arg GOPROXY=...)
|
# Set Go proxy (can be overridden with --build-arg GOPROXY=...)
|
||||||
ARG GOPROXY=https://proxy.golang.org,direct
|
ARG GOPROXY=https://proxy.golang.org,direct
|
||||||
ENV GOPROXY=${GOPROXY}
|
ENV GOPROXY=${GOPROXY}
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
COPY go.mod go.sum ./
|
COPY go.mod go.sum ./
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
||||||
-ldflags="-w -s" \
|
-ldflags="-w -s" \
|
||||||
-o /bin/message-router \
|
-o /bin/message-router \
|
||||||
./services/message-router/cmd/server
|
./services/message-router/cmd/server
|
||||||
|
|
||||||
# Final stage
|
# Final stage
|
||||||
FROM alpine:3.18
|
FROM alpine:3.18
|
||||||
|
|
||||||
RUN apk --no-cache add ca-certificates curl
|
RUN apk --no-cache add ca-certificates curl
|
||||||
RUN adduser -D -s /bin/sh mpc
|
RUN adduser -D -s /bin/sh mpc
|
||||||
|
|
||||||
COPY --from=builder /bin/message-router /bin/message-router
|
COPY --from=builder /bin/message-router /bin/message-router
|
||||||
|
|
||||||
USER mpc
|
USER mpc
|
||||||
|
|
||||||
EXPOSE 50051 8080
|
EXPOSE 50051 8080
|
||||||
|
|
||||||
# Health check
|
# Health check
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||||
CMD curl -sf http://localhost:8080/health || exit 1
|
CMD curl -sf http://localhost:8080/health || exit 1
|
||||||
|
|
||||||
ENTRYPOINT ["/bin/message-router"]
|
ENTRYPOINT ["/bin/message-router"]
|
||||||
|
|
|
||||||
|
|
@ -1,162 +1,267 @@
|
||||||
package grpc
|
package grpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
pb "github.com/rwadurian/mpc-system/api/grpc/router/v1"
|
pb "github.com/rwadurian/mpc-system/api/grpc/router/v1"
|
||||||
"github.com/rwadurian/mpc-system/services/message-router/adapters/output/rabbitmq"
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
||||||
"github.com/rwadurian/mpc-system/services/message-router/application/use_cases"
|
"github.com/rwadurian/mpc-system/services/message-router/adapters/output/rabbitmq"
|
||||||
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
|
"github.com/rwadurian/mpc-system/services/message-router/application/use_cases"
|
||||||
"google.golang.org/grpc/codes"
|
"github.com/rwadurian/mpc-system/services/message-router/domain"
|
||||||
"google.golang.org/grpc/status"
|
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
|
||||||
)
|
"go.uber.org/zap"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
// MessageRouterServer implements the gRPC MessageRouter service
|
"google.golang.org/grpc/status"
|
||||||
type MessageRouterServer struct {
|
)
|
||||||
pb.UnimplementedMessageRouterServer
|
|
||||||
routeMessageUC *use_cases.RouteMessageUseCase
|
// MessageRouterServer implements the gRPC MessageRouter service
|
||||||
getPendingMessagesUC *use_cases.GetPendingMessagesUseCase
|
type MessageRouterServer struct {
|
||||||
messageBroker *rabbitmq.MessageBrokerAdapter
|
pb.UnimplementedMessageRouterServer
|
||||||
}
|
routeMessageUC *use_cases.RouteMessageUseCase
|
||||||
|
getPendingMessagesUC *use_cases.GetPendingMessagesUseCase
|
||||||
// NewMessageRouterServer creates a new gRPC server
|
messageBroker *rabbitmq.MessageBrokerAdapter
|
||||||
func NewMessageRouterServer(
|
partyRegistry *domain.PartyRegistry
|
||||||
routeMessageUC *use_cases.RouteMessageUseCase,
|
eventBroadcaster *domain.SessionEventBroadcaster
|
||||||
getPendingMessagesUC *use_cases.GetPendingMessagesUseCase,
|
}
|
||||||
messageBroker *rabbitmq.MessageBrokerAdapter,
|
|
||||||
) *MessageRouterServer {
|
// NewMessageRouterServer creates a new gRPC server
|
||||||
return &MessageRouterServer{
|
func NewMessageRouterServer(
|
||||||
routeMessageUC: routeMessageUC,
|
routeMessageUC *use_cases.RouteMessageUseCase,
|
||||||
getPendingMessagesUC: getPendingMessagesUC,
|
getPendingMessagesUC *use_cases.GetPendingMessagesUseCase,
|
||||||
messageBroker: messageBroker,
|
messageBroker *rabbitmq.MessageBrokerAdapter,
|
||||||
}
|
partyRegistry *domain.PartyRegistry,
|
||||||
}
|
eventBroadcaster *domain.SessionEventBroadcaster,
|
||||||
|
) *MessageRouterServer {
|
||||||
// RouteMessage routes an MPC message
|
return &MessageRouterServer{
|
||||||
func (s *MessageRouterServer) RouteMessage(
|
routeMessageUC: routeMessageUC,
|
||||||
ctx context.Context,
|
getPendingMessagesUC: getPendingMessagesUC,
|
||||||
req *pb.RouteMessageRequest,
|
messageBroker: messageBroker,
|
||||||
) (*pb.RouteMessageResponse, error) {
|
partyRegistry: partyRegistry,
|
||||||
input := use_cases.RouteMessageInput{
|
eventBroadcaster: eventBroadcaster,
|
||||||
SessionID: req.SessionId,
|
}
|
||||||
FromParty: req.FromParty,
|
}
|
||||||
ToParties: req.ToParties,
|
|
||||||
RoundNumber: int(req.RoundNumber),
|
// RouteMessage routes an MPC message
|
||||||
MessageType: req.MessageType,
|
func (s *MessageRouterServer) RouteMessage(
|
||||||
Payload: req.Payload,
|
ctx context.Context,
|
||||||
}
|
req *pb.RouteMessageRequest,
|
||||||
|
) (*pb.RouteMessageResponse, error) {
|
||||||
output, err := s.routeMessageUC.Execute(ctx, input)
|
input := use_cases.RouteMessageInput{
|
||||||
if err != nil {
|
SessionID: req.SessionId,
|
||||||
return nil, toGRPCError(err)
|
FromParty: req.FromParty,
|
||||||
}
|
ToParties: req.ToParties,
|
||||||
|
RoundNumber: int(req.RoundNumber),
|
||||||
return &pb.RouteMessageResponse{
|
MessageType: req.MessageType,
|
||||||
Success: output.Success,
|
Payload: req.Payload,
|
||||||
MessageId: output.MessageID,
|
}
|
||||||
}, nil
|
|
||||||
}
|
output, err := s.routeMessageUC.Execute(ctx, input)
|
||||||
|
if err != nil {
|
||||||
// SubscribeMessages subscribes to messages for a party (streaming)
|
return nil, toGRPCError(err)
|
||||||
func (s *MessageRouterServer) SubscribeMessages(
|
}
|
||||||
req *pb.SubscribeMessagesRequest,
|
|
||||||
stream pb.MessageRouter_SubscribeMessagesServer,
|
return &pb.RouteMessageResponse{
|
||||||
) error {
|
Success: output.Success,
|
||||||
ctx := stream.Context()
|
MessageId: output.MessageID,
|
||||||
|
}, nil
|
||||||
// Subscribe to party messages
|
}
|
||||||
partyCh, err := s.messageBroker.SubscribeToPartyMessages(ctx, req.PartyId)
|
|
||||||
if err != nil {
|
// SubscribeMessages subscribes to messages for a party (streaming)
|
||||||
return status.Error(codes.Internal, err.Error())
|
func (s *MessageRouterServer) SubscribeMessages(
|
||||||
}
|
req *pb.SubscribeMessagesRequest,
|
||||||
|
stream pb.MessageRouter_SubscribeMessagesServer,
|
||||||
// Subscribe to session messages (broadcasts)
|
) error {
|
||||||
sessionCh, err := s.messageBroker.SubscribeToSessionMessages(ctx, req.SessionId, req.PartyId)
|
ctx := stream.Context()
|
||||||
if err != nil {
|
|
||||||
return status.Error(codes.Internal, err.Error())
|
// Subscribe to party messages
|
||||||
}
|
partyCh, err := s.messageBroker.SubscribeToPartyMessages(ctx, req.PartyId)
|
||||||
|
if err != nil {
|
||||||
// Merge channels and stream messages
|
return status.Error(codes.Internal, err.Error())
|
||||||
for {
|
}
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
// Subscribe to session messages (broadcasts)
|
||||||
return nil
|
sessionCh, err := s.messageBroker.SubscribeToSessionMessages(ctx, req.SessionId, req.PartyId)
|
||||||
case msg, ok := <-partyCh:
|
if err != nil {
|
||||||
if !ok {
|
return status.Error(codes.Internal, err.Error())
|
||||||
return nil
|
}
|
||||||
}
|
|
||||||
if err := sendMessage(stream, msg); err != nil {
|
// Merge channels and stream messages
|
||||||
return err
|
for {
|
||||||
}
|
select {
|
||||||
case msg, ok := <-sessionCh:
|
case <-ctx.Done():
|
||||||
if !ok {
|
return nil
|
||||||
return nil
|
case msg, ok := <-partyCh:
|
||||||
}
|
if !ok {
|
||||||
if err := sendMessage(stream, msg); err != nil {
|
return nil
|
||||||
return err
|
}
|
||||||
}
|
if err := sendMessage(stream, msg); err != nil {
|
||||||
}
|
return err
|
||||||
}
|
}
|
||||||
}
|
case msg, ok := <-sessionCh:
|
||||||
|
if !ok {
|
||||||
// GetPendingMessages retrieves pending messages (polling alternative)
|
return nil
|
||||||
func (s *MessageRouterServer) GetPendingMessages(
|
}
|
||||||
ctx context.Context,
|
if err := sendMessage(stream, msg); err != nil {
|
||||||
req *pb.GetPendingMessagesRequest,
|
return err
|
||||||
) (*pb.GetPendingMessagesResponse, error) {
|
}
|
||||||
input := use_cases.GetPendingMessagesInput{
|
}
|
||||||
SessionID: req.SessionId,
|
}
|
||||||
PartyID: req.PartyId,
|
}
|
||||||
AfterTimestamp: req.AfterTimestamp,
|
|
||||||
}
|
// GetPendingMessages retrieves pending messages (polling alternative)
|
||||||
|
func (s *MessageRouterServer) GetPendingMessages(
|
||||||
messages, err := s.getPendingMessagesUC.Execute(ctx, input)
|
ctx context.Context,
|
||||||
if err != nil {
|
req *pb.GetPendingMessagesRequest,
|
||||||
return nil, toGRPCError(err)
|
) (*pb.GetPendingMessagesResponse, error) {
|
||||||
}
|
input := use_cases.GetPendingMessagesInput{
|
||||||
|
SessionID: req.SessionId,
|
||||||
protoMessages := make([]*pb.MPCMessage, len(messages))
|
PartyID: req.PartyId,
|
||||||
for i, msg := range messages {
|
AfterTimestamp: req.AfterTimestamp,
|
||||||
protoMessages[i] = &pb.MPCMessage{
|
}
|
||||||
MessageId: msg.ID,
|
|
||||||
SessionId: msg.SessionID,
|
messages, err := s.getPendingMessagesUC.Execute(ctx, input)
|
||||||
FromParty: msg.FromParty,
|
if err != nil {
|
||||||
IsBroadcast: msg.IsBroadcast,
|
return nil, toGRPCError(err)
|
||||||
RoundNumber: int32(msg.RoundNumber),
|
}
|
||||||
MessageType: msg.MessageType,
|
|
||||||
Payload: msg.Payload,
|
protoMessages := make([]*pb.MPCMessage, len(messages))
|
||||||
CreatedAt: msg.CreatedAt,
|
for i, msg := range messages {
|
||||||
}
|
protoMessages[i] = &pb.MPCMessage{
|
||||||
}
|
MessageId: msg.ID,
|
||||||
|
SessionId: msg.SessionID,
|
||||||
return &pb.GetPendingMessagesResponse{
|
FromParty: msg.FromParty,
|
||||||
Messages: protoMessages,
|
IsBroadcast: msg.IsBroadcast,
|
||||||
}, nil
|
RoundNumber: int32(msg.RoundNumber),
|
||||||
}
|
MessageType: msg.MessageType,
|
||||||
|
Payload: msg.Payload,
|
||||||
func sendMessage(stream pb.MessageRouter_SubscribeMessagesServer, msg *entities.MessageDTO) error {
|
CreatedAt: msg.CreatedAt,
|
||||||
protoMsg := &pb.MPCMessage{
|
}
|
||||||
MessageId: msg.ID,
|
}
|
||||||
SessionId: msg.SessionID,
|
|
||||||
FromParty: msg.FromParty,
|
return &pb.GetPendingMessagesResponse{
|
||||||
IsBroadcast: msg.IsBroadcast,
|
Messages: protoMessages,
|
||||||
RoundNumber: int32(msg.RoundNumber),
|
}, nil
|
||||||
MessageType: msg.MessageType,
|
}
|
||||||
Payload: msg.Payload,
|
|
||||||
CreatedAt: msg.CreatedAt,
|
// RegisterParty registers a party with the message router
|
||||||
}
|
func (s *MessageRouterServer) RegisterParty(
|
||||||
return stream.Send(protoMsg)
|
ctx context.Context,
|
||||||
}
|
req *pb.RegisterPartyRequest,
|
||||||
|
) (*pb.RegisterPartyResponse, error) {
|
||||||
func toGRPCError(err error) error {
|
if req.PartyId == "" {
|
||||||
switch err {
|
return nil, status.Error(codes.InvalidArgument, "party_id is required")
|
||||||
case use_cases.ErrInvalidSessionID:
|
}
|
||||||
return status.Error(codes.InvalidArgument, err.Error())
|
|
||||||
case use_cases.ErrInvalidPartyID:
|
// Register party
|
||||||
return status.Error(codes.InvalidArgument, err.Error())
|
party := s.partyRegistry.Register(req.PartyId, req.PartyRole, req.Version)
|
||||||
case use_cases.ErrEmptyPayload:
|
|
||||||
return status.Error(codes.InvalidArgument, err.Error())
|
logger.Info("Party registered",
|
||||||
default:
|
zap.String("party_id", req.PartyId),
|
||||||
return status.Error(codes.Internal, err.Error())
|
zap.String("role", req.PartyRole),
|
||||||
}
|
zap.String("version", req.Version))
|
||||||
}
|
|
||||||
|
return &pb.RegisterPartyResponse{
|
||||||
|
Success: true,
|
||||||
|
Message: "Party registered successfully",
|
||||||
|
RegisteredAt: party.RegisteredAt.UnixMilli(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubscribeSessionEvents subscribes to session lifecycle events (streaming)
|
||||||
|
func (s *MessageRouterServer) SubscribeSessionEvents(
|
||||||
|
req *pb.SubscribeSessionEventsRequest,
|
||||||
|
stream pb.MessageRouter_SubscribeSessionEventsServer,
|
||||||
|
) error {
|
||||||
|
ctx := stream.Context()
|
||||||
|
|
||||||
|
if req.PartyId == "" {
|
||||||
|
return status.Error(codes.InvalidArgument, "party_id is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if party is registered
|
||||||
|
if _, exists := s.partyRegistry.Get(req.PartyId); !exists {
|
||||||
|
return status.Error(codes.FailedPrecondition, "party not registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Party subscribed to session events",
|
||||||
|
zap.String("party_id", req.PartyId))
|
||||||
|
|
||||||
|
// Subscribe to events
|
||||||
|
eventCh := s.eventBroadcaster.Subscribe(req.PartyId)
|
||||||
|
defer s.eventBroadcaster.Unsubscribe(req.PartyId)
|
||||||
|
|
||||||
|
// Stream events
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.Info("Party unsubscribed from session events",
|
||||||
|
zap.String("party_id", req.PartyId))
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case event, ok := <-eventCh:
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send event to party
|
||||||
|
if err := stream.Send(event); err != nil {
|
||||||
|
logger.Error("Failed to send session event",
|
||||||
|
zap.String("party_id", req.PartyId),
|
||||||
|
zap.Error(err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Sent session event to party",
|
||||||
|
zap.String("party_id", req.PartyId),
|
||||||
|
zap.String("event_type", event.EventType),
|
||||||
|
zap.String("session_id", event.SessionId))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishSessionEvent publishes a session event to subscribed parties
|
||||||
|
// This is called by Session Coordinator
|
||||||
|
func (s *MessageRouterServer) PublishSessionEvent(event *pb.SessionEvent) {
|
||||||
|
// If selected_parties is specified, send only to those parties
|
||||||
|
if len(event.SelectedParties) > 0 {
|
||||||
|
s.eventBroadcaster.BroadcastToParties(event, event.SelectedParties)
|
||||||
|
logger.Info("Published session event to selected parties",
|
||||||
|
zap.String("event_type", event.EventType),
|
||||||
|
zap.String("session_id", event.SessionId),
|
||||||
|
zap.Int("party_count", len(event.SelectedParties)))
|
||||||
|
} else {
|
||||||
|
// Broadcast to all subscribers
|
||||||
|
s.eventBroadcaster.Broadcast(event)
|
||||||
|
logger.Info("Broadcast session event to all parties",
|
||||||
|
zap.String("event_type", event.EventType),
|
||||||
|
zap.String("session_id", event.SessionId),
|
||||||
|
zap.Int("subscriber_count", s.eventBroadcaster.SubscriberCount()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendMessage(stream pb.MessageRouter_SubscribeMessagesServer, msg *entities.MessageDTO) error {
|
||||||
|
protoMsg := &pb.MPCMessage{
|
||||||
|
MessageId: msg.ID,
|
||||||
|
SessionId: msg.SessionID,
|
||||||
|
FromParty: msg.FromParty,
|
||||||
|
IsBroadcast: msg.IsBroadcast,
|
||||||
|
RoundNumber: int32(msg.RoundNumber),
|
||||||
|
MessageType: msg.MessageType,
|
||||||
|
Payload: msg.Payload,
|
||||||
|
CreatedAt: msg.CreatedAt,
|
||||||
|
}
|
||||||
|
return stream.Send(protoMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toGRPCError(err error) error {
|
||||||
|
switch err {
|
||||||
|
case use_cases.ErrInvalidSessionID:
|
||||||
|
return status.Error(codes.InvalidArgument, err.Error())
|
||||||
|
case use_cases.ErrInvalidPartyID:
|
||||||
|
return status.Error(codes.InvalidArgument, err.Error())
|
||||||
|
case use_cases.ErrEmptyPayload:
|
||||||
|
return status.Error(codes.InvalidArgument, err.Error())
|
||||||
|
default:
|
||||||
|
return status.Error(codes.Internal, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,169 +1,169 @@
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
|
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/message-router/domain/repositories"
|
"github.com/rwadurian/mpc-system/services/message-router/domain/repositories"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MessagePostgresRepo implements MessageRepository for PostgreSQL
|
// MessagePostgresRepo implements MessageRepository for PostgreSQL
|
||||||
type MessagePostgresRepo struct {
|
type MessagePostgresRepo struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMessagePostgresRepo creates a new PostgreSQL message repository
|
// NewMessagePostgresRepo creates a new PostgreSQL message repository
|
||||||
func NewMessagePostgresRepo(db *sql.DB) *MessagePostgresRepo {
|
func NewMessagePostgresRepo(db *sql.DB) *MessagePostgresRepo {
|
||||||
return &MessagePostgresRepo{db: db}
|
return &MessagePostgresRepo{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save persists a new message
|
// Save persists a new message
|
||||||
func (r *MessagePostgresRepo) Save(ctx context.Context, msg *entities.MPCMessage) error {
|
func (r *MessagePostgresRepo) Save(ctx context.Context, msg *entities.MPCMessage) error {
|
||||||
_, err := r.db.ExecContext(ctx, `
|
_, err := r.db.ExecContext(ctx, `
|
||||||
INSERT INTO mpc_messages (
|
INSERT INTO mpc_messages (
|
||||||
id, session_id, from_party, to_parties, round_number, message_type, payload, created_at
|
id, session_id, from_party, to_parties, round_number, message_type, payload, created_at
|
||||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
`,
|
`,
|
||||||
msg.ID,
|
msg.ID,
|
||||||
msg.SessionID,
|
msg.SessionID,
|
||||||
msg.FromParty,
|
msg.FromParty,
|
||||||
pq.Array(msg.ToParties),
|
pq.Array(msg.ToParties),
|
||||||
msg.RoundNumber,
|
msg.RoundNumber,
|
||||||
msg.MessageType,
|
msg.MessageType,
|
||||||
msg.Payload,
|
msg.Payload,
|
||||||
msg.CreatedAt,
|
msg.CreatedAt,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID retrieves a message by ID
|
// GetByID retrieves a message by ID
|
||||||
func (r *MessagePostgresRepo) GetByID(ctx context.Context, id uuid.UUID) (*entities.MPCMessage, error) {
|
func (r *MessagePostgresRepo) GetByID(ctx context.Context, id uuid.UUID) (*entities.MPCMessage, error) {
|
||||||
var msg entities.MPCMessage
|
var msg entities.MPCMessage
|
||||||
var toParties []string
|
var toParties []string
|
||||||
|
|
||||||
err := r.db.QueryRowContext(ctx, `
|
err := r.db.QueryRowContext(ctx, `
|
||||||
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
||||||
FROM mpc_messages WHERE id = $1
|
FROM mpc_messages WHERE id = $1
|
||||||
`, id).Scan(
|
`, id).Scan(
|
||||||
&msg.ID,
|
&msg.ID,
|
||||||
&msg.SessionID,
|
&msg.SessionID,
|
||||||
&msg.FromParty,
|
&msg.FromParty,
|
||||||
pq.Array(&toParties),
|
pq.Array(&toParties),
|
||||||
&msg.RoundNumber,
|
&msg.RoundNumber,
|
||||||
&msg.MessageType,
|
&msg.MessageType,
|
||||||
&msg.Payload,
|
&msg.Payload,
|
||||||
&msg.CreatedAt,
|
&msg.CreatedAt,
|
||||||
&msg.DeliveredAt,
|
&msg.DeliveredAt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
msg.ToParties = toParties
|
msg.ToParties = toParties
|
||||||
return &msg, nil
|
return &msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPendingMessages retrieves pending messages for a party
|
// GetPendingMessages retrieves pending messages for a party
|
||||||
func (r *MessagePostgresRepo) GetPendingMessages(
|
func (r *MessagePostgresRepo) GetPendingMessages(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID uuid.UUID,
|
sessionID uuid.UUID,
|
||||||
partyID string,
|
partyID string,
|
||||||
afterTime time.Time,
|
afterTime time.Time,
|
||||||
) ([]*entities.MPCMessage, error) {
|
) ([]*entities.MPCMessage, error) {
|
||||||
rows, err := r.db.QueryContext(ctx, `
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
||||||
FROM mpc_messages
|
FROM mpc_messages
|
||||||
WHERE session_id = $1
|
WHERE session_id = $1
|
||||||
AND created_at > $2
|
AND created_at > $2
|
||||||
AND from_party != $3
|
AND from_party != $3
|
||||||
AND (to_parties IS NULL OR cardinality(to_parties) = 0 OR $3 = ANY(to_parties))
|
AND (to_parties IS NULL OR cardinality(to_parties) = 0 OR $3 = ANY(to_parties))
|
||||||
ORDER BY round_number ASC, created_at ASC
|
ORDER BY round_number ASC, created_at ASC
|
||||||
`, sessionID, afterTime, partyID)
|
`, sessionID, afterTime, partyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
return r.scanMessages(rows)
|
return r.scanMessages(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMessagesByRound retrieves messages for a specific round
|
// GetMessagesByRound retrieves messages for a specific round
|
||||||
func (r *MessagePostgresRepo) GetMessagesByRound(
|
func (r *MessagePostgresRepo) GetMessagesByRound(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID uuid.UUID,
|
sessionID uuid.UUID,
|
||||||
roundNumber int,
|
roundNumber int,
|
||||||
) ([]*entities.MPCMessage, error) {
|
) ([]*entities.MPCMessage, error) {
|
||||||
rows, err := r.db.QueryContext(ctx, `
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
||||||
FROM mpc_messages
|
FROM mpc_messages
|
||||||
WHERE session_id = $1 AND round_number = $2
|
WHERE session_id = $1 AND round_number = $2
|
||||||
ORDER BY created_at ASC
|
ORDER BY created_at ASC
|
||||||
`, sessionID, roundNumber)
|
`, sessionID, roundNumber)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
return r.scanMessages(rows)
|
return r.scanMessages(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkDelivered marks a message as delivered
|
// MarkDelivered marks a message as delivered
|
||||||
func (r *MessagePostgresRepo) MarkDelivered(ctx context.Context, messageID uuid.UUID) error {
|
func (r *MessagePostgresRepo) MarkDelivered(ctx context.Context, messageID uuid.UUID) error {
|
||||||
_, err := r.db.ExecContext(ctx, `
|
_, err := r.db.ExecContext(ctx, `
|
||||||
UPDATE mpc_messages SET delivered_at = NOW() WHERE id = $1
|
UPDATE mpc_messages SET delivered_at = NOW() WHERE id = $1
|
||||||
`, messageID)
|
`, messageID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteBySession deletes all messages for a session
|
// DeleteBySession deletes all messages for a session
|
||||||
func (r *MessagePostgresRepo) DeleteBySession(ctx context.Context, sessionID uuid.UUID) error {
|
func (r *MessagePostgresRepo) DeleteBySession(ctx context.Context, sessionID uuid.UUID) error {
|
||||||
_, err := r.db.ExecContext(ctx, `DELETE FROM mpc_messages WHERE session_id = $1`, sessionID)
|
_, err := r.db.ExecContext(ctx, `DELETE FROM mpc_messages WHERE session_id = $1`, sessionID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteOlderThan deletes messages older than a specific time
|
// DeleteOlderThan deletes messages older than a specific time
|
||||||
func (r *MessagePostgresRepo) DeleteOlderThan(ctx context.Context, before time.Time) (int64, error) {
|
func (r *MessagePostgresRepo) DeleteOlderThan(ctx context.Context, before time.Time) (int64, error) {
|
||||||
result, err := r.db.ExecContext(ctx, `DELETE FROM mpc_messages WHERE created_at < $1`, before)
|
result, err := r.db.ExecContext(ctx, `DELETE FROM mpc_messages WHERE created_at < $1`, before)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return result.RowsAffected()
|
return result.RowsAffected()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *MessagePostgresRepo) scanMessages(rows *sql.Rows) ([]*entities.MPCMessage, error) {
|
func (r *MessagePostgresRepo) scanMessages(rows *sql.Rows) ([]*entities.MPCMessage, error) {
|
||||||
var messages []*entities.MPCMessage
|
var messages []*entities.MPCMessage
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var msg entities.MPCMessage
|
var msg entities.MPCMessage
|
||||||
var toParties []string
|
var toParties []string
|
||||||
|
|
||||||
err := rows.Scan(
|
err := rows.Scan(
|
||||||
&msg.ID,
|
&msg.ID,
|
||||||
&msg.SessionID,
|
&msg.SessionID,
|
||||||
&msg.FromParty,
|
&msg.FromParty,
|
||||||
pq.Array(&toParties),
|
pq.Array(&toParties),
|
||||||
&msg.RoundNumber,
|
&msg.RoundNumber,
|
||||||
&msg.MessageType,
|
&msg.MessageType,
|
||||||
&msg.Payload,
|
&msg.Payload,
|
||||||
&msg.CreatedAt,
|
&msg.CreatedAt,
|
||||||
&msg.DeliveredAt,
|
&msg.DeliveredAt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
msg.ToParties = toParties
|
msg.ToParties = toParties
|
||||||
messages = append(messages, &msg)
|
messages = append(messages, &msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages, rows.Err()
|
return messages, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure interface compliance
|
// Ensure interface compliance
|
||||||
var _ repositories.MessageRepository = (*MessagePostgresRepo)(nil)
|
var _ repositories.MessageRepository = (*MessagePostgresRepo)(nil)
|
||||||
|
|
|
||||||
|
|
@ -1,388 +1,388 @@
|
||||||
package rabbitmq
|
package rabbitmq
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
amqp "github.com/rabbitmq/amqp091-go"
|
amqp "github.com/rabbitmq/amqp091-go"
|
||||||
"github.com/rwadurian/mpc-system/pkg/logger"
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
||||||
"github.com/rwadurian/mpc-system/services/message-router/application/use_cases"
|
"github.com/rwadurian/mpc-system/services/message-router/application/use_cases"
|
||||||
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
|
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MessageBrokerAdapter implements MessageBroker using RabbitMQ
|
// MessageBrokerAdapter implements MessageBroker using RabbitMQ
|
||||||
type MessageBrokerAdapter struct {
|
type MessageBrokerAdapter struct {
|
||||||
conn *amqp.Connection
|
conn *amqp.Connection
|
||||||
channel *amqp.Channel
|
channel *amqp.Channel
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMessageBrokerAdapter creates a new RabbitMQ message broker
|
// NewMessageBrokerAdapter creates a new RabbitMQ message broker
|
||||||
func NewMessageBrokerAdapter(conn *amqp.Connection) (*MessageBrokerAdapter, error) {
|
func NewMessageBrokerAdapter(conn *amqp.Connection) (*MessageBrokerAdapter, error) {
|
||||||
channel, err := conn.Channel()
|
channel, err := conn.Channel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create channel: %w", err)
|
return nil, fmt.Errorf("failed to create channel: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Declare exchange for party messages
|
// Declare exchange for party messages
|
||||||
err = channel.ExchangeDeclare(
|
err = channel.ExchangeDeclare(
|
||||||
"mpc.messages", // name
|
"mpc.messages", // name
|
||||||
"direct", // type
|
"direct", // type
|
||||||
true, // durable
|
true, // durable
|
||||||
false, // auto-deleted
|
false, // auto-deleted
|
||||||
false, // internal
|
false, // internal
|
||||||
false, // no-wait
|
false, // no-wait
|
||||||
nil, // arguments
|
nil, // arguments
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to declare exchange: %w", err)
|
return nil, fmt.Errorf("failed to declare exchange: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Declare exchange for session broadcasts
|
// Declare exchange for session broadcasts
|
||||||
err = channel.ExchangeDeclare(
|
err = channel.ExchangeDeclare(
|
||||||
"mpc.session.broadcast", // name
|
"mpc.session.broadcast", // name
|
||||||
"fanout", // type
|
"fanout", // type
|
||||||
true, // durable
|
true, // durable
|
||||||
false, // auto-deleted
|
false, // auto-deleted
|
||||||
false, // internal
|
false, // internal
|
||||||
false, // no-wait
|
false, // no-wait
|
||||||
nil, // arguments
|
nil, // arguments
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to declare broadcast exchange: %w", err)
|
return nil, fmt.Errorf("failed to declare broadcast exchange: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &MessageBrokerAdapter{
|
return &MessageBrokerAdapter{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
channel: channel,
|
channel: channel,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PublishToParty publishes a message to a specific party
|
// PublishToParty publishes a message to a specific party
|
||||||
func (a *MessageBrokerAdapter) PublishToParty(ctx context.Context, partyID string, message *entities.MessageDTO) error {
|
func (a *MessageBrokerAdapter) PublishToParty(ctx context.Context, partyID string, message *entities.MessageDTO) error {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
defer a.mu.Unlock()
|
defer a.mu.Unlock()
|
||||||
|
|
||||||
// Ensure queue exists for the party
|
// Ensure queue exists for the party
|
||||||
queueName := fmt.Sprintf("mpc.party.%s", partyID)
|
queueName := fmt.Sprintf("mpc.party.%s", partyID)
|
||||||
_, err := a.channel.QueueDeclare(
|
_, err := a.channel.QueueDeclare(
|
||||||
queueName, // name
|
queueName, // name
|
||||||
true, // durable
|
true, // durable
|
||||||
false, // delete when unused
|
false, // delete when unused
|
||||||
false, // exclusive
|
false, // exclusive
|
||||||
false, // no-wait
|
false, // no-wait
|
||||||
nil, // arguments
|
nil, // arguments
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to declare queue: %w", err)
|
return fmt.Errorf("failed to declare queue: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bind queue to exchange
|
// Bind queue to exchange
|
||||||
err = a.channel.QueueBind(
|
err = a.channel.QueueBind(
|
||||||
queueName, // queue name
|
queueName, // queue name
|
||||||
partyID, // routing key
|
partyID, // routing key
|
||||||
"mpc.messages", // exchange
|
"mpc.messages", // exchange
|
||||||
false, // no-wait
|
false, // no-wait
|
||||||
nil, // arguments
|
nil, // arguments
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to bind queue: %w", err)
|
return fmt.Errorf("failed to bind queue: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := json.Marshal(message)
|
body, err := json.Marshal(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal message: %w", err)
|
return fmt.Errorf("failed to marshal message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.channel.PublishWithContext(
|
err = a.channel.PublishWithContext(
|
||||||
ctx,
|
ctx,
|
||||||
"mpc.messages", // exchange
|
"mpc.messages", // exchange
|
||||||
partyID, // routing key
|
partyID, // routing key
|
||||||
false, // mandatory
|
false, // mandatory
|
||||||
false, // immediate
|
false, // immediate
|
||||||
amqp.Publishing{
|
amqp.Publishing{
|
||||||
ContentType: "application/json",
|
ContentType: "application/json",
|
||||||
DeliveryMode: amqp.Persistent,
|
DeliveryMode: amqp.Persistent,
|
||||||
Body: body,
|
Body: body,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to publish message: %w", err)
|
return fmt.Errorf("failed to publish message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("published message to party",
|
logger.Debug("published message to party",
|
||||||
zap.String("party_id", partyID),
|
zap.String("party_id", partyID),
|
||||||
zap.String("message_id", message.ID))
|
zap.String("message_id", message.ID))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PublishToSession publishes a message to all parties in a session (except sender)
|
// PublishToSession publishes a message to all parties in a session (except sender)
|
||||||
func (a *MessageBrokerAdapter) PublishToSession(
|
func (a *MessageBrokerAdapter) PublishToSession(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID string,
|
sessionID string,
|
||||||
excludeParty string,
|
excludeParty string,
|
||||||
message *entities.MessageDTO,
|
message *entities.MessageDTO,
|
||||||
) error {
|
) error {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
defer a.mu.Unlock()
|
defer a.mu.Unlock()
|
||||||
|
|
||||||
// Use session-specific exchange
|
// Use session-specific exchange
|
||||||
exchangeName := fmt.Sprintf("mpc.session.%s", sessionID)
|
exchangeName := fmt.Sprintf("mpc.session.%s", sessionID)
|
||||||
|
|
||||||
// Declare session-specific fanout exchange
|
// Declare session-specific fanout exchange
|
||||||
err := a.channel.ExchangeDeclare(
|
err := a.channel.ExchangeDeclare(
|
||||||
exchangeName, // name
|
exchangeName, // name
|
||||||
"fanout", // type
|
"fanout", // type
|
||||||
false, // durable (temporary for session)
|
false, // durable (temporary for session)
|
||||||
true, // auto-delete when unused
|
true, // auto-delete when unused
|
||||||
false, // internal
|
false, // internal
|
||||||
false, // no-wait
|
false, // no-wait
|
||||||
nil, // arguments
|
nil, // arguments
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to declare session exchange: %w", err)
|
return fmt.Errorf("failed to declare session exchange: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := json.Marshal(message)
|
body, err := json.Marshal(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal message: %w", err)
|
return fmt.Errorf("failed to marshal message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.channel.PublishWithContext(
|
err = a.channel.PublishWithContext(
|
||||||
ctx,
|
ctx,
|
||||||
exchangeName, // exchange
|
exchangeName, // exchange
|
||||||
"", // routing key (ignored for fanout)
|
"", // routing key (ignored for fanout)
|
||||||
false, // mandatory
|
false, // mandatory
|
||||||
false, // immediate
|
false, // immediate
|
||||||
amqp.Publishing{
|
amqp.Publishing{
|
||||||
ContentType: "application/json",
|
ContentType: "application/json",
|
||||||
DeliveryMode: amqp.Persistent,
|
DeliveryMode: amqp.Persistent,
|
||||||
Body: body,
|
Body: body,
|
||||||
Headers: amqp.Table{
|
Headers: amqp.Table{
|
||||||
"exclude_party": excludeParty,
|
"exclude_party": excludeParty,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to publish broadcast: %w", err)
|
return fmt.Errorf("failed to publish broadcast: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("broadcast message to session",
|
logger.Debug("broadcast message to session",
|
||||||
zap.String("session_id", sessionID),
|
zap.String("session_id", sessionID),
|
||||||
zap.String("message_id", message.ID),
|
zap.String("message_id", message.ID),
|
||||||
zap.String("exclude_party", excludeParty))
|
zap.String("exclude_party", excludeParty))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubscribeToPartyMessages subscribes to messages for a specific party
|
// SubscribeToPartyMessages subscribes to messages for a specific party
|
||||||
func (a *MessageBrokerAdapter) SubscribeToPartyMessages(
|
func (a *MessageBrokerAdapter) SubscribeToPartyMessages(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
partyID string,
|
partyID string,
|
||||||
) (<-chan *entities.MessageDTO, error) {
|
) (<-chan *entities.MessageDTO, error) {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
defer a.mu.Unlock()
|
defer a.mu.Unlock()
|
||||||
|
|
||||||
queueName := fmt.Sprintf("mpc.party.%s", partyID)
|
queueName := fmt.Sprintf("mpc.party.%s", partyID)
|
||||||
|
|
||||||
// Ensure queue exists
|
// Ensure queue exists
|
||||||
_, err := a.channel.QueueDeclare(
|
_, err := a.channel.QueueDeclare(
|
||||||
queueName, // name
|
queueName, // name
|
||||||
true, // durable
|
true, // durable
|
||||||
false, // delete when unused
|
false, // delete when unused
|
||||||
false, // exclusive
|
false, // exclusive
|
||||||
false, // no-wait
|
false, // no-wait
|
||||||
nil, // arguments
|
nil, // arguments
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to declare queue: %w", err)
|
return nil, fmt.Errorf("failed to declare queue: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bind queue to exchange
|
// Bind queue to exchange
|
||||||
err = a.channel.QueueBind(
|
err = a.channel.QueueBind(
|
||||||
queueName, // queue name
|
queueName, // queue name
|
||||||
partyID, // routing key
|
partyID, // routing key
|
||||||
"mpc.messages", // exchange
|
"mpc.messages", // exchange
|
||||||
false, // no-wait
|
false, // no-wait
|
||||||
nil, // arguments
|
nil, // arguments
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to bind queue: %w", err)
|
return nil, fmt.Errorf("failed to bind queue: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start consuming
|
// Start consuming
|
||||||
msgs, err := a.channel.Consume(
|
msgs, err := a.channel.Consume(
|
||||||
queueName, // queue
|
queueName, // queue
|
||||||
"", // consumer
|
"", // consumer
|
||||||
false, // auto-ack (we'll ack manually)
|
false, // auto-ack (we'll ack manually)
|
||||||
false, // exclusive
|
false, // exclusive
|
||||||
false, // no-local
|
false, // no-local
|
||||||
false, // no-wait
|
false, // no-wait
|
||||||
nil, // args
|
nil, // args
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to register consumer: %w", err)
|
return nil, fmt.Errorf("failed to register consumer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create output channel
|
// Create output channel
|
||||||
out := make(chan *entities.MessageDTO, 100)
|
out := make(chan *entities.MessageDTO, 100)
|
||||||
|
|
||||||
// Start goroutine to forward messages
|
// Start goroutine to forward messages
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case msg, ok := <-msgs:
|
case msg, ok := <-msgs:
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var dto entities.MessageDTO
|
var dto entities.MessageDTO
|
||||||
if err := json.Unmarshal(msg.Body, &dto); err != nil {
|
if err := json.Unmarshal(msg.Body, &dto); err != nil {
|
||||||
logger.Error("failed to unmarshal message", zap.Error(err))
|
logger.Error("failed to unmarshal message", zap.Error(err))
|
||||||
msg.Nack(false, false)
|
msg.Nack(false, false)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case out <- &dto:
|
case out <- &dto:
|
||||||
msg.Ack(false)
|
msg.Ack(false)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
msg.Nack(false, true) // Requeue
|
msg.Nack(false, true) // Requeue
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubscribeToSessionMessages subscribes to all messages in a session
|
// SubscribeToSessionMessages subscribes to all messages in a session
|
||||||
func (a *MessageBrokerAdapter) SubscribeToSessionMessages(
|
func (a *MessageBrokerAdapter) SubscribeToSessionMessages(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID string,
|
sessionID string,
|
||||||
partyID string,
|
partyID string,
|
||||||
) (<-chan *entities.MessageDTO, error) {
|
) (<-chan *entities.MessageDTO, error) {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
defer a.mu.Unlock()
|
defer a.mu.Unlock()
|
||||||
|
|
||||||
exchangeName := fmt.Sprintf("mpc.session.%s", sessionID)
|
exchangeName := fmt.Sprintf("mpc.session.%s", sessionID)
|
||||||
queueName := fmt.Sprintf("mpc.session.%s.%s", sessionID, partyID)
|
queueName := fmt.Sprintf("mpc.session.%s.%s", sessionID, partyID)
|
||||||
|
|
||||||
// Declare session-specific fanout exchange
|
// Declare session-specific fanout exchange
|
||||||
err := a.channel.ExchangeDeclare(
|
err := a.channel.ExchangeDeclare(
|
||||||
exchangeName, // name
|
exchangeName, // name
|
||||||
"fanout", // type
|
"fanout", // type
|
||||||
false, // durable
|
false, // durable
|
||||||
true, // auto-delete
|
true, // auto-delete
|
||||||
false, // internal
|
false, // internal
|
||||||
false, // no-wait
|
false, // no-wait
|
||||||
nil, // arguments
|
nil, // arguments
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to declare session exchange: %w", err)
|
return nil, fmt.Errorf("failed to declare session exchange: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Declare temporary queue for this subscriber
|
// Declare temporary queue for this subscriber
|
||||||
_, err = a.channel.QueueDeclare(
|
_, err = a.channel.QueueDeclare(
|
||||||
queueName, // name
|
queueName, // name
|
||||||
false, // durable
|
false, // durable
|
||||||
true, // delete when unused
|
true, // delete when unused
|
||||||
true, // exclusive
|
true, // exclusive
|
||||||
false, // no-wait
|
false, // no-wait
|
||||||
nil, // arguments
|
nil, // arguments
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to declare queue: %w", err)
|
return nil, fmt.Errorf("failed to declare queue: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bind queue to session exchange
|
// Bind queue to session exchange
|
||||||
err = a.channel.QueueBind(
|
err = a.channel.QueueBind(
|
||||||
queueName, // queue name
|
queueName, // queue name
|
||||||
"", // routing key (ignored for fanout)
|
"", // routing key (ignored for fanout)
|
||||||
exchangeName, // exchange
|
exchangeName, // exchange
|
||||||
false, // no-wait
|
false, // no-wait
|
||||||
nil, // arguments
|
nil, // arguments
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to bind queue: %w", err)
|
return nil, fmt.Errorf("failed to bind queue: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start consuming
|
// Start consuming
|
||||||
msgs, err := a.channel.Consume(
|
msgs, err := a.channel.Consume(
|
||||||
queueName, // queue
|
queueName, // queue
|
||||||
"", // consumer
|
"", // consumer
|
||||||
false, // auto-ack
|
false, // auto-ack
|
||||||
true, // exclusive
|
true, // exclusive
|
||||||
false, // no-local
|
false, // no-local
|
||||||
false, // no-wait
|
false, // no-wait
|
||||||
nil, // args
|
nil, // args
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to register consumer: %w", err)
|
return nil, fmt.Errorf("failed to register consumer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create output channel
|
// Create output channel
|
||||||
out := make(chan *entities.MessageDTO, 100)
|
out := make(chan *entities.MessageDTO, 100)
|
||||||
|
|
||||||
// Start goroutine to forward messages
|
// Start goroutine to forward messages
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case msg, ok := <-msgs:
|
case msg, ok := <-msgs:
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this message should be excluded for this party
|
// Check if this message should be excluded for this party
|
||||||
if excludeParty, ok := msg.Headers["exclude_party"].(string); ok {
|
if excludeParty, ok := msg.Headers["exclude_party"].(string); ok {
|
||||||
if excludeParty == partyID {
|
if excludeParty == partyID {
|
||||||
msg.Ack(false)
|
msg.Ack(false)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var dto entities.MessageDTO
|
var dto entities.MessageDTO
|
||||||
if err := json.Unmarshal(msg.Body, &dto); err != nil {
|
if err := json.Unmarshal(msg.Body, &dto); err != nil {
|
||||||
logger.Error("failed to unmarshal message", zap.Error(err))
|
logger.Error("failed to unmarshal message", zap.Error(err))
|
||||||
msg.Nack(false, false)
|
msg.Nack(false, false)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case out <- &dto:
|
case out <- &dto:
|
||||||
msg.Ack(false)
|
msg.Ack(false)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
msg.Nack(false, true)
|
msg.Nack(false, true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the connection
|
// Close closes the connection
|
||||||
func (a *MessageBrokerAdapter) Close() error {
|
func (a *MessageBrokerAdapter) Close() error {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
defer a.mu.Unlock()
|
defer a.mu.Unlock()
|
||||||
|
|
||||||
if a.channel != nil {
|
if a.channel != nil {
|
||||||
return a.channel.Close()
|
return a.channel.Close()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure interface compliance
|
// Ensure interface compliance
|
||||||
var _ use_cases.MessageBroker = (*MessageBrokerAdapter)(nil)
|
var _ use_cases.MessageBroker = (*MessageBrokerAdapter)(nil)
|
||||||
|
|
|
||||||
|
|
@ -1,170 +1,170 @@
|
||||||
package use_cases
|
package use_cases
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/pkg/logger"
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
||||||
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
|
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/message-router/domain/repositories"
|
"github.com/rwadurian/mpc-system/services/message-router/domain/repositories"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidSessionID = errors.New("invalid session ID")
|
ErrInvalidSessionID = errors.New("invalid session ID")
|
||||||
ErrInvalidPartyID = errors.New("invalid party ID")
|
ErrInvalidPartyID = errors.New("invalid party ID")
|
||||||
ErrEmptyPayload = errors.New("empty payload")
|
ErrEmptyPayload = errors.New("empty payload")
|
||||||
)
|
)
|
||||||
|
|
||||||
// RouteMessageInput contains input for routing a message
|
// RouteMessageInput contains input for routing a message
|
||||||
type RouteMessageInput struct {
|
type RouteMessageInput struct {
|
||||||
SessionID string
|
SessionID string
|
||||||
FromParty string
|
FromParty string
|
||||||
ToParties []string // nil/empty means broadcast
|
ToParties []string // nil/empty means broadcast
|
||||||
RoundNumber int
|
RoundNumber int
|
||||||
MessageType string
|
MessageType string
|
||||||
Payload []byte
|
Payload []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouteMessageOutput contains output from routing a message
|
// RouteMessageOutput contains output from routing a message
|
||||||
type RouteMessageOutput struct {
|
type RouteMessageOutput struct {
|
||||||
MessageID string
|
MessageID string
|
||||||
Success bool
|
Success bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// MessageBroker defines the interface for message delivery
|
// MessageBroker defines the interface for message delivery
|
||||||
type MessageBroker interface {
|
type MessageBroker interface {
|
||||||
// PublishToParty publishes a message to a specific party
|
// PublishToParty publishes a message to a specific party
|
||||||
PublishToParty(ctx context.Context, partyID string, message *entities.MessageDTO) error
|
PublishToParty(ctx context.Context, partyID string, message *entities.MessageDTO) error
|
||||||
// PublishToSession publishes a message to all parties in a session (except sender)
|
// PublishToSession publishes a message to all parties in a session (except sender)
|
||||||
PublishToSession(ctx context.Context, sessionID string, excludeParty string, message *entities.MessageDTO) error
|
PublishToSession(ctx context.Context, sessionID string, excludeParty string, message *entities.MessageDTO) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouteMessageUseCase handles message routing
|
// RouteMessageUseCase handles message routing
|
||||||
type RouteMessageUseCase struct {
|
type RouteMessageUseCase struct {
|
||||||
messageRepo repositories.MessageRepository
|
messageRepo repositories.MessageRepository
|
||||||
messageBroker MessageBroker
|
messageBroker MessageBroker
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRouteMessageUseCase creates a new route message use case
|
// NewRouteMessageUseCase creates a new route message use case
|
||||||
func NewRouteMessageUseCase(
|
func NewRouteMessageUseCase(
|
||||||
messageRepo repositories.MessageRepository,
|
messageRepo repositories.MessageRepository,
|
||||||
messageBroker MessageBroker,
|
messageBroker MessageBroker,
|
||||||
) *RouteMessageUseCase {
|
) *RouteMessageUseCase {
|
||||||
return &RouteMessageUseCase{
|
return &RouteMessageUseCase{
|
||||||
messageRepo: messageRepo,
|
messageRepo: messageRepo,
|
||||||
messageBroker: messageBroker,
|
messageBroker: messageBroker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute routes an MPC message
|
// Execute routes an MPC message
|
||||||
func (uc *RouteMessageUseCase) Execute(ctx context.Context, input RouteMessageInput) (*RouteMessageOutput, error) {
|
func (uc *RouteMessageUseCase) Execute(ctx context.Context, input RouteMessageInput) (*RouteMessageOutput, error) {
|
||||||
// Validate input
|
// Validate input
|
||||||
sessionID, err := uuid.Parse(input.SessionID)
|
sessionID, err := uuid.Parse(input.SessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidSessionID
|
return nil, ErrInvalidSessionID
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.FromParty == "" {
|
if input.FromParty == "" {
|
||||||
return nil, ErrInvalidPartyID
|
return nil, ErrInvalidPartyID
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(input.Payload) == 0 {
|
if len(input.Payload) == 0 {
|
||||||
return nil, ErrEmptyPayload
|
return nil, ErrEmptyPayload
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create message entity
|
// Create message entity
|
||||||
msg := entities.NewMPCMessage(
|
msg := entities.NewMPCMessage(
|
||||||
sessionID,
|
sessionID,
|
||||||
input.FromParty,
|
input.FromParty,
|
||||||
input.ToParties,
|
input.ToParties,
|
||||||
input.RoundNumber,
|
input.RoundNumber,
|
||||||
input.MessageType,
|
input.MessageType,
|
||||||
input.Payload,
|
input.Payload,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Persist message for reliability (offline scenarios)
|
// Persist message for reliability (offline scenarios)
|
||||||
if err := uc.messageRepo.Save(ctx, msg); err != nil {
|
if err := uc.messageRepo.Save(ctx, msg); err != nil {
|
||||||
logger.Error("failed to save message", zap.Error(err))
|
logger.Error("failed to save message", zap.Error(err))
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route message
|
// Route message
|
||||||
dto := msg.ToDTO()
|
dto := msg.ToDTO()
|
||||||
if msg.IsBroadcast() {
|
if msg.IsBroadcast() {
|
||||||
// Broadcast to all parties except sender
|
// Broadcast to all parties except sender
|
||||||
if err := uc.messageBroker.PublishToSession(ctx, input.SessionID, input.FromParty, &dto); err != nil {
|
if err := uc.messageBroker.PublishToSession(ctx, input.SessionID, input.FromParty, &dto); err != nil {
|
||||||
logger.Error("failed to broadcast message",
|
logger.Error("failed to broadcast message",
|
||||||
zap.String("session_id", input.SessionID),
|
zap.String("session_id", input.SessionID),
|
||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
// Don't fail - message is persisted and can be retrieved via polling
|
// Don't fail - message is persisted and can be retrieved via polling
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Unicast to specific parties
|
// Unicast to specific parties
|
||||||
for _, toParty := range input.ToParties {
|
for _, toParty := range input.ToParties {
|
||||||
if err := uc.messageBroker.PublishToParty(ctx, toParty, &dto); err != nil {
|
if err := uc.messageBroker.PublishToParty(ctx, toParty, &dto); err != nil {
|
||||||
logger.Error("failed to send message to party",
|
logger.Error("failed to send message to party",
|
||||||
zap.String("party_id", toParty),
|
zap.String("party_id", toParty),
|
||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
// Don't fail - continue sending to other parties
|
// Don't fail - continue sending to other parties
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &RouteMessageOutput{
|
return &RouteMessageOutput{
|
||||||
MessageID: msg.ID.String(),
|
MessageID: msg.ID.String(),
|
||||||
Success: true,
|
Success: true,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPendingMessagesInput contains input for getting pending messages
|
// GetPendingMessagesInput contains input for getting pending messages
|
||||||
type GetPendingMessagesInput struct {
|
type GetPendingMessagesInput struct {
|
||||||
SessionID string
|
SessionID string
|
||||||
PartyID string
|
PartyID string
|
||||||
AfterTimestamp int64
|
AfterTimestamp int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPendingMessagesUseCase retrieves pending messages for a party
|
// GetPendingMessagesUseCase retrieves pending messages for a party
|
||||||
type GetPendingMessagesUseCase struct {
|
type GetPendingMessagesUseCase struct {
|
||||||
messageRepo repositories.MessageRepository
|
messageRepo repositories.MessageRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGetPendingMessagesUseCase creates a new get pending messages use case
|
// NewGetPendingMessagesUseCase creates a new get pending messages use case
|
||||||
func NewGetPendingMessagesUseCase(messageRepo repositories.MessageRepository) *GetPendingMessagesUseCase {
|
func NewGetPendingMessagesUseCase(messageRepo repositories.MessageRepository) *GetPendingMessagesUseCase {
|
||||||
return &GetPendingMessagesUseCase{
|
return &GetPendingMessagesUseCase{
|
||||||
messageRepo: messageRepo,
|
messageRepo: messageRepo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute retrieves pending messages
|
// Execute retrieves pending messages
|
||||||
func (uc *GetPendingMessagesUseCase) Execute(ctx context.Context, input GetPendingMessagesInput) ([]*entities.MessageDTO, error) {
|
func (uc *GetPendingMessagesUseCase) Execute(ctx context.Context, input GetPendingMessagesInput) ([]*entities.MessageDTO, error) {
|
||||||
sessionID, err := uuid.Parse(input.SessionID)
|
sessionID, err := uuid.Parse(input.SessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidSessionID
|
return nil, ErrInvalidSessionID
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.PartyID == "" {
|
if input.PartyID == "" {
|
||||||
return nil, ErrInvalidPartyID
|
return nil, ErrInvalidPartyID
|
||||||
}
|
}
|
||||||
|
|
||||||
afterTime := time.Time{}
|
afterTime := time.Time{}
|
||||||
if input.AfterTimestamp > 0 {
|
if input.AfterTimestamp > 0 {
|
||||||
afterTime = time.UnixMilli(input.AfterTimestamp)
|
afterTime = time.UnixMilli(input.AfterTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
messages, err := uc.messageRepo.GetPendingMessages(ctx, sessionID, input.PartyID, afterTime)
|
messages, err := uc.messageRepo.GetPendingMessages(ctx, sessionID, input.PartyID, afterTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to DTOs
|
// Convert to DTOs
|
||||||
dtos := make([]*entities.MessageDTO, len(messages))
|
dtos := make([]*entities.MessageDTO, len(messages))
|
||||||
for i, msg := range messages {
|
for i, msg := range messages {
|
||||||
dto := msg.ToDTO()
|
dto := msg.ToDTO()
|
||||||
dtos[i] = &dto
|
dtos[i] = &dto
|
||||||
}
|
}
|
||||||
|
|
||||||
return dtos, nil
|
return dtos, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,320 +1,424 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
amqp "github.com/rabbitmq/amqp091-go"
|
amqp "github.com/rabbitmq/amqp091-go"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/reflection"
|
"google.golang.org/grpc/reflection"
|
||||||
|
|
||||||
pb "github.com/rwadurian/mpc-system/api/grpc/router/v1"
|
pb "github.com/rwadurian/mpc-system/api/grpc/router/v1"
|
||||||
"github.com/rwadurian/mpc-system/pkg/config"
|
"github.com/rwadurian/mpc-system/pkg/config"
|
||||||
"github.com/rwadurian/mpc-system/pkg/logger"
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
||||||
grpcadapter "github.com/rwadurian/mpc-system/services/message-router/adapters/input/grpc"
|
grpcadapter "github.com/rwadurian/mpc-system/services/message-router/adapters/input/grpc"
|
||||||
"github.com/rwadurian/mpc-system/services/message-router/adapters/output/postgres"
|
"github.com/rwadurian/mpc-system/services/message-router/adapters/output/postgres"
|
||||||
"github.com/rwadurian/mpc-system/services/message-router/adapters/output/rabbitmq"
|
"github.com/rwadurian/mpc-system/services/message-router/adapters/output/rabbitmq"
|
||||||
"github.com/rwadurian/mpc-system/services/message-router/application/use_cases"
|
"github.com/rwadurian/mpc-system/services/message-router/application/use_cases"
|
||||||
"go.uber.org/zap"
|
"github.com/rwadurian/mpc-system/services/message-router/domain"
|
||||||
)
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
func main() {
|
|
||||||
// Parse flags
|
func main() {
|
||||||
configPath := flag.String("config", "", "Path to config file")
|
// Parse flags
|
||||||
flag.Parse()
|
configPath := flag.String("config", "", "Path to config file")
|
||||||
|
flag.Parse()
|
||||||
// Load configuration
|
|
||||||
cfg, err := config.Load(*configPath)
|
// Load configuration
|
||||||
if err != nil {
|
cfg, err := config.Load(*configPath)
|
||||||
fmt.Printf("Failed to load config: %v\n", err)
|
if err != nil {
|
||||||
os.Exit(1)
|
fmt.Printf("Failed to load config: %v\n", err)
|
||||||
}
|
os.Exit(1)
|
||||||
|
}
|
||||||
// Initialize logger
|
|
||||||
if err := logger.Init(&logger.Config{
|
// Initialize logger
|
||||||
Level: cfg.Logger.Level,
|
if err := logger.Init(&logger.Config{
|
||||||
Encoding: cfg.Logger.Encoding,
|
Level: cfg.Logger.Level,
|
||||||
}); err != nil {
|
Encoding: cfg.Logger.Encoding,
|
||||||
fmt.Printf("Failed to initialize logger: %v\n", err)
|
}); err != nil {
|
||||||
os.Exit(1)
|
fmt.Printf("Failed to initialize logger: %v\n", err)
|
||||||
}
|
os.Exit(1)
|
||||||
defer logger.Sync()
|
}
|
||||||
|
defer logger.Sync()
|
||||||
logger.Info("Starting Message Router Service",
|
|
||||||
zap.String("environment", cfg.Server.Environment),
|
logger.Info("Starting Message Router Service",
|
||||||
zap.Int("grpc_port", cfg.Server.GRPCPort),
|
zap.String("environment", cfg.Server.Environment),
|
||||||
zap.Int("http_port", cfg.Server.HTTPPort))
|
zap.Int("grpc_port", cfg.Server.GRPCPort),
|
||||||
|
zap.Int("http_port", cfg.Server.HTTPPort))
|
||||||
// Initialize database connection
|
|
||||||
db, err := initDatabase(cfg.Database)
|
// Initialize database connection
|
||||||
if err != nil {
|
db, err := initDatabase(cfg.Database)
|
||||||
logger.Fatal("Failed to connect to database", zap.Error(err))
|
if err != nil {
|
||||||
}
|
logger.Fatal("Failed to connect to database", zap.Error(err))
|
||||||
defer db.Close()
|
}
|
||||||
|
defer db.Close()
|
||||||
// Initialize RabbitMQ connection
|
|
||||||
rabbitConn, err := initRabbitMQ(cfg.RabbitMQ)
|
// Initialize RabbitMQ connection
|
||||||
if err != nil {
|
rabbitConn, err := initRabbitMQ(cfg.RabbitMQ)
|
||||||
logger.Fatal("Failed to connect to RabbitMQ", zap.Error(err))
|
if err != nil {
|
||||||
}
|
logger.Fatal("Failed to connect to RabbitMQ", zap.Error(err))
|
||||||
defer rabbitConn.Close()
|
}
|
||||||
|
defer rabbitConn.Close()
|
||||||
// Initialize repositories and adapters
|
|
||||||
messageRepo := postgres.NewMessagePostgresRepo(db)
|
// Initialize repositories and adapters
|
||||||
messageBroker, err := rabbitmq.NewMessageBrokerAdapter(rabbitConn)
|
messageRepo := postgres.NewMessagePostgresRepo(db)
|
||||||
if err != nil {
|
messageBroker, err := rabbitmq.NewMessageBrokerAdapter(rabbitConn)
|
||||||
logger.Fatal("Failed to create message broker", zap.Error(err))
|
if err != nil {
|
||||||
}
|
logger.Fatal("Failed to create message broker", zap.Error(err))
|
||||||
defer messageBroker.Close()
|
}
|
||||||
|
defer messageBroker.Close()
|
||||||
// Initialize use cases
|
|
||||||
routeMessageUC := use_cases.NewRouteMessageUseCase(messageRepo, messageBroker)
|
// Initialize party registry and event broadcaster for party-driven architecture
|
||||||
getPendingMessagesUC := use_cases.NewGetPendingMessagesUseCase(messageRepo)
|
partyRegistry := domain.NewPartyRegistry()
|
||||||
|
eventBroadcaster := domain.NewSessionEventBroadcaster()
|
||||||
// Start message cleanup background job
|
|
||||||
go runMessageCleanup(messageRepo)
|
// Initialize use cases
|
||||||
|
routeMessageUC := use_cases.NewRouteMessageUseCase(messageRepo, messageBroker)
|
||||||
// Create shutdown context
|
getPendingMessagesUC := use_cases.NewGetPendingMessagesUseCase(messageRepo)
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
// Start message cleanup background job
|
||||||
|
go runMessageCleanup(messageRepo)
|
||||||
// Start servers
|
|
||||||
errChan := make(chan error, 2)
|
// Create shutdown context
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
// Start gRPC server
|
defer cancel()
|
||||||
go func() {
|
|
||||||
if err := startGRPCServer(cfg, routeMessageUC, getPendingMessagesUC, messageBroker); err != nil {
|
// Start servers
|
||||||
errChan <- fmt.Errorf("gRPC server error: %w", err)
|
errChan := make(chan error, 2)
|
||||||
}
|
|
||||||
}()
|
// Start gRPC server
|
||||||
|
go func() {
|
||||||
// Start HTTP server
|
if err := startGRPCServer(cfg, routeMessageUC, getPendingMessagesUC, messageBroker, partyRegistry, eventBroadcaster); err != nil {
|
||||||
go func() {
|
errChan <- fmt.Errorf("gRPC server error: %w", err)
|
||||||
if err := startHTTPServer(cfg, routeMessageUC, getPendingMessagesUC); err != nil {
|
}
|
||||||
errChan <- fmt.Errorf("HTTP server error: %w", err)
|
}()
|
||||||
}
|
|
||||||
}()
|
// Start HTTP server
|
||||||
|
go func() {
|
||||||
// Wait for shutdown signal
|
if err := startHTTPServer(cfg, routeMessageUC, getPendingMessagesUC); err != nil {
|
||||||
sigChan := make(chan os.Signal, 1)
|
errChan <- fmt.Errorf("HTTP server error: %w", err)
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
}
|
||||||
|
}()
|
||||||
select {
|
|
||||||
case sig := <-sigChan:
|
// Wait for shutdown signal
|
||||||
logger.Info("Received shutdown signal", zap.String("signal", sig.String()))
|
sigChan := make(chan os.Signal, 1)
|
||||||
case err := <-errChan:
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
logger.Error("Server error", zap.Error(err))
|
|
||||||
}
|
select {
|
||||||
|
case sig := <-sigChan:
|
||||||
// Graceful shutdown
|
logger.Info("Received shutdown signal", zap.String("signal", sig.String()))
|
||||||
logger.Info("Shutting down...")
|
case err := <-errChan:
|
||||||
cancel()
|
logger.Error("Server error", zap.Error(err))
|
||||||
|
}
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
logger.Info("Shutdown complete")
|
// Graceful shutdown
|
||||||
|
logger.Info("Shutting down...")
|
||||||
_ = ctx
|
cancel()
|
||||||
}
|
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) {
|
logger.Info("Shutdown complete")
|
||||||
const maxRetries = 10
|
|
||||||
const retryDelay = 2 * time.Second
|
_ = ctx
|
||||||
|
}
|
||||||
var db *sql.DB
|
|
||||||
var err error
|
func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) {
|
||||||
|
const maxRetries = 10
|
||||||
for i := 0; i < maxRetries; i++ {
|
const retryDelay = 2 * time.Second
|
||||||
db, err = sql.Open("postgres", cfg.DSN())
|
|
||||||
if err != nil {
|
var db *sql.DB
|
||||||
logger.Warn("Failed to open database connection, retrying...",
|
var err error
|
||||||
zap.Int("attempt", i+1),
|
|
||||||
zap.Int("max_retries", maxRetries),
|
for i := 0; i < maxRetries; i++ {
|
||||||
zap.Error(err))
|
db, err = sql.Open("postgres", cfg.DSN())
|
||||||
time.Sleep(retryDelay * time.Duration(i+1))
|
if err != nil {
|
||||||
continue
|
logger.Warn("Failed to open database connection, retrying...",
|
||||||
}
|
zap.Int("attempt", i+1),
|
||||||
|
zap.Int("max_retries", maxRetries),
|
||||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
zap.Error(err))
|
||||||
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
db.SetConnMaxLifetime(cfg.ConnMaxLife)
|
continue
|
||||||
|
}
|
||||||
if err = db.Ping(); err != nil {
|
|
||||||
logger.Warn("Failed to ping database, retrying...",
|
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||||
zap.Int("attempt", i+1),
|
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||||
zap.Int("max_retries", maxRetries),
|
db.SetConnMaxLifetime(cfg.ConnMaxLife)
|
||||||
zap.Error(err))
|
|
||||||
db.Close()
|
// Test connection with Ping
|
||||||
time.Sleep(retryDelay * time.Duration(i+1))
|
if err = db.Ping(); err != nil {
|
||||||
continue
|
logger.Warn("Failed to ping database, retrying...",
|
||||||
}
|
zap.Int("attempt", i+1),
|
||||||
|
zap.Int("max_retries", maxRetries),
|
||||||
logger.Info("Connected to PostgreSQL")
|
zap.Error(err))
|
||||||
return db, nil
|
db.Close()
|
||||||
}
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
|
continue
|
||||||
return nil, fmt.Errorf("failed to connect to database after %d retries: %w", maxRetries, err)
|
}
|
||||||
}
|
|
||||||
|
// Verify database is actually usable with a simple query
|
||||||
func initRabbitMQ(cfg config.RabbitMQConfig) (*amqp.Connection, error) {
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
const maxRetries = 10
|
var result int
|
||||||
const retryDelay = 2 * time.Second
|
err = db.QueryRowContext(ctx, "SELECT 1").Scan(&result)
|
||||||
|
cancel()
|
||||||
var conn *amqp.Connection
|
if err != nil {
|
||||||
var err error
|
logger.Warn("Database ping succeeded but query failed, retrying...",
|
||||||
|
zap.Int("attempt", i+1),
|
||||||
for i := 0; i < maxRetries; i++ {
|
zap.Int("max_retries", maxRetries),
|
||||||
conn, err = amqp.Dial(cfg.URL())
|
zap.Error(err))
|
||||||
if err != nil {
|
db.Close()
|
||||||
logger.Warn("Failed to connect to RabbitMQ, retrying...",
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
zap.Int("attempt", i+1),
|
continue
|
||||||
zap.Int("max_retries", maxRetries),
|
}
|
||||||
zap.Error(err))
|
|
||||||
time.Sleep(retryDelay * time.Duration(i+1))
|
logger.Info("Connected to PostgreSQL and verified connectivity",
|
||||||
continue
|
zap.Int("attempt", i+1))
|
||||||
}
|
return db, nil
|
||||||
|
}
|
||||||
logger.Info("Connected to RabbitMQ")
|
|
||||||
return conn, nil
|
return nil, fmt.Errorf("failed to connect to database after %d retries: %w", maxRetries, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("failed to connect to RabbitMQ after %d retries: %w", maxRetries, err)
|
func initRabbitMQ(cfg config.RabbitMQConfig) (*amqp.Connection, error) {
|
||||||
}
|
const maxRetries = 10
|
||||||
|
const retryDelay = 2 * time.Second
|
||||||
func startGRPCServer(
|
|
||||||
cfg *config.Config,
|
var conn *amqp.Connection
|
||||||
routeMessageUC *use_cases.RouteMessageUseCase,
|
var err error
|
||||||
getPendingMessagesUC *use_cases.GetPendingMessagesUseCase,
|
|
||||||
messageBroker *rabbitmq.MessageBrokerAdapter,
|
for i := 0; i < maxRetries; i++ {
|
||||||
) error {
|
// Attempt to dial RabbitMQ
|
||||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", cfg.Server.GRPCPort))
|
conn, err = amqp.Dial(cfg.URL())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
logger.Warn("Failed to dial RabbitMQ, retrying...",
|
||||||
}
|
zap.Int("attempt", i+1),
|
||||||
|
zap.Int("max_retries", maxRetries),
|
||||||
grpcServer := grpc.NewServer()
|
zap.String("url", maskPassword(cfg.URL())),
|
||||||
|
zap.Error(err))
|
||||||
// Create and register the message router gRPC handler
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
messageRouterServer := grpcadapter.NewMessageRouterServer(
|
continue
|
||||||
routeMessageUC,
|
}
|
||||||
getPendingMessagesUC,
|
|
||||||
messageBroker,
|
// Verify connection is actually usable by opening a channel
|
||||||
)
|
ch, err := conn.Channel()
|
||||||
pb.RegisterMessageRouterServer(grpcServer, messageRouterServer)
|
if err != nil {
|
||||||
|
logger.Warn("RabbitMQ connection established but channel creation failed, retrying...",
|
||||||
// Enable reflection for debugging
|
zap.Int("attempt", i+1),
|
||||||
reflection.Register(grpcServer)
|
zap.Int("max_retries", maxRetries),
|
||||||
|
zap.Error(err))
|
||||||
logger.Info("Starting gRPC server", zap.Int("port", cfg.Server.GRPCPort))
|
conn.Close()
|
||||||
return grpcServer.Serve(listener)
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
func startHTTPServer(
|
|
||||||
cfg *config.Config,
|
// Test the channel with a simple operation (declare a test exchange)
|
||||||
routeMessageUC *use_cases.RouteMessageUseCase,
|
err = ch.ExchangeDeclare(
|
||||||
getPendingMessagesUC *use_cases.GetPendingMessagesUseCase,
|
"mpc.health.check", // name
|
||||||
) error {
|
"fanout", // type
|
||||||
if cfg.Server.Environment == "production" {
|
false, // durable
|
||||||
gin.SetMode(gin.ReleaseMode)
|
true, // auto-deleted
|
||||||
}
|
false, // internal
|
||||||
|
false, // no-wait
|
||||||
router := gin.New()
|
nil, // arguments
|
||||||
router.Use(gin.Recovery())
|
)
|
||||||
router.Use(gin.Logger())
|
if err != nil {
|
||||||
|
logger.Warn("RabbitMQ channel created but exchange declaration failed, retrying...",
|
||||||
// Health check
|
zap.Int("attempt", i+1),
|
||||||
router.GET("/health", func(c *gin.Context) {
|
zap.Int("max_retries", maxRetries),
|
||||||
c.JSON(http.StatusOK, gin.H{
|
zap.Error(err))
|
||||||
"status": "healthy",
|
ch.Close()
|
||||||
"service": "message-router",
|
conn.Close()
|
||||||
})
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
})
|
continue
|
||||||
|
}
|
||||||
// API routes
|
|
||||||
api := router.Group("/api/v1")
|
// Clean up test exchange
|
||||||
{
|
ch.ExchangeDelete("mpc.health.check", false, false)
|
||||||
api.POST("/messages/route", func(c *gin.Context) {
|
ch.Close()
|
||||||
var req struct {
|
|
||||||
SessionID string `json:"session_id" binding:"required"`
|
// Setup connection close notification
|
||||||
FromParty string `json:"from_party" binding:"required"`
|
closeChan := make(chan *amqp.Error, 1)
|
||||||
ToParties []string `json:"to_parties"`
|
conn.NotifyClose(closeChan)
|
||||||
RoundNumber int `json:"round_number"`
|
go func() {
|
||||||
MessageType string `json:"message_type"`
|
err := <-closeChan
|
||||||
Payload []byte `json:"payload" binding:"required"`
|
if err != nil {
|
||||||
}
|
logger.Error("RabbitMQ connection closed unexpectedly", zap.Error(err))
|
||||||
|
}
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
}()
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
logger.Info("Connected to RabbitMQ and verified connectivity",
|
||||||
}
|
zap.Int("attempt", i+1))
|
||||||
|
return conn, nil
|
||||||
input := use_cases.RouteMessageInput{
|
}
|
||||||
SessionID: req.SessionID,
|
|
||||||
FromParty: req.FromParty,
|
return nil, fmt.Errorf("failed to connect to RabbitMQ after %d retries: %w", maxRetries, err)
|
||||||
ToParties: req.ToParties,
|
}
|
||||||
RoundNumber: req.RoundNumber,
|
|
||||||
MessageType: req.MessageType,
|
// maskPassword masks the password in the RabbitMQ URL for logging
|
||||||
Payload: req.Payload,
|
func maskPassword(url string) string {
|
||||||
}
|
// Simple masking: amqp://user:password@host:port -> amqp://user:****@host:port
|
||||||
|
start := 0
|
||||||
output, err := routeMessageUC.Execute(c.Request.Context(), input)
|
for i := 0; i < len(url); i++ {
|
||||||
if err != nil {
|
if url[i] == ':' && i > 0 && url[i-1] != '/' {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
start = i + 1
|
||||||
return
|
break
|
||||||
}
|
}
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
if start == 0 {
|
||||||
"success": output.Success,
|
return url
|
||||||
"message_id": output.MessageID,
|
}
|
||||||
})
|
|
||||||
})
|
end := start
|
||||||
|
for i := start; i < len(url); i++ {
|
||||||
api.GET("/messages/pending", func(c *gin.Context) {
|
if url[i] == '@' {
|
||||||
input := use_cases.GetPendingMessagesInput{
|
end = i
|
||||||
SessionID: c.Query("session_id"),
|
break
|
||||||
PartyID: c.Query("party_id"),
|
}
|
||||||
AfterTimestamp: 0,
|
}
|
||||||
}
|
if end == start {
|
||||||
|
return url
|
||||||
messages, err := getPendingMessagesUC.Execute(c.Request.Context(), input)
|
}
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
return url[:start] + "****" + url[end:]
|
||||||
return
|
}
|
||||||
}
|
|
||||||
|
func startGRPCServer(
|
||||||
c.JSON(http.StatusOK, gin.H{"messages": messages})
|
cfg *config.Config,
|
||||||
})
|
routeMessageUC *use_cases.RouteMessageUseCase,
|
||||||
}
|
getPendingMessagesUC *use_cases.GetPendingMessagesUseCase,
|
||||||
|
messageBroker *rabbitmq.MessageBrokerAdapter,
|
||||||
logger.Info("Starting HTTP server", zap.Int("port", cfg.Server.HTTPPort))
|
partyRegistry *domain.PartyRegistry,
|
||||||
return router.Run(fmt.Sprintf(":%d", cfg.Server.HTTPPort))
|
eventBroadcaster *domain.SessionEventBroadcaster,
|
||||||
}
|
) error {
|
||||||
|
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", cfg.Server.GRPCPort))
|
||||||
func runMessageCleanup(messageRepo *postgres.MessagePostgresRepo) {
|
if err != nil {
|
||||||
ticker := time.NewTicker(1 * time.Hour)
|
return err
|
||||||
defer ticker.Stop()
|
}
|
||||||
|
|
||||||
for range ticker.C {
|
grpcServer := grpc.NewServer()
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
|
||||||
|
// Create and register the message router gRPC handler with party registry and event broadcaster
|
||||||
// Delete messages older than 24 hours
|
messageRouterServer := grpcadapter.NewMessageRouterServer(
|
||||||
cutoff := time.Now().Add(-24 * time.Hour)
|
routeMessageUC,
|
||||||
count, err := messageRepo.DeleteOlderThan(ctx, cutoff)
|
getPendingMessagesUC,
|
||||||
cancel()
|
messageBroker,
|
||||||
|
partyRegistry,
|
||||||
if err != nil {
|
eventBroadcaster,
|
||||||
logger.Error("Failed to cleanup old messages", zap.Error(err))
|
)
|
||||||
} else if count > 0 {
|
pb.RegisterMessageRouterServer(grpcServer, messageRouterServer)
|
||||||
logger.Info("Cleaned up old messages", zap.Int64("count", count))
|
|
||||||
}
|
// Enable reflection for debugging
|
||||||
}
|
reflection.Register(grpcServer)
|
||||||
}
|
|
||||||
|
logger.Info("Starting gRPC server", zap.Int("port", cfg.Server.GRPCPort))
|
||||||
|
return grpcServer.Serve(listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func startHTTPServer(
|
||||||
|
cfg *config.Config,
|
||||||
|
routeMessageUC *use_cases.RouteMessageUseCase,
|
||||||
|
getPendingMessagesUC *use_cases.GetPendingMessagesUseCase,
|
||||||
|
) error {
|
||||||
|
if cfg.Server.Environment == "production" {
|
||||||
|
gin.SetMode(gin.ReleaseMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(gin.Recovery())
|
||||||
|
router.Use(gin.Logger())
|
||||||
|
|
||||||
|
// Health check
|
||||||
|
router.GET("/health", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": "healthy",
|
||||||
|
"service": "message-router",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// API routes
|
||||||
|
api := router.Group("/api/v1")
|
||||||
|
{
|
||||||
|
api.POST("/messages/route", func(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
SessionID string `json:"session_id" binding:"required"`
|
||||||
|
FromParty string `json:"from_party" binding:"required"`
|
||||||
|
ToParties []string `json:"to_parties"`
|
||||||
|
RoundNumber int `json:"round_number"`
|
||||||
|
MessageType string `json:"message_type"`
|
||||||
|
Payload []byte `json:"payload" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
input := use_cases.RouteMessageInput{
|
||||||
|
SessionID: req.SessionID,
|
||||||
|
FromParty: req.FromParty,
|
||||||
|
ToParties: req.ToParties,
|
||||||
|
RoundNumber: req.RoundNumber,
|
||||||
|
MessageType: req.MessageType,
|
||||||
|
Payload: req.Payload,
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := routeMessageUC.Execute(c.Request.Context(), input)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": output.Success,
|
||||||
|
"message_id": output.MessageID,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
api.GET("/messages/pending", func(c *gin.Context) {
|
||||||
|
input := use_cases.GetPendingMessagesInput{
|
||||||
|
SessionID: c.Query("session_id"),
|
||||||
|
PartyID: c.Query("party_id"),
|
||||||
|
AfterTimestamp: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := getPendingMessagesUC.Execute(c.Request.Context(), input)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"messages": messages})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Starting HTTP server", zap.Int("port", cfg.Server.HTTPPort))
|
||||||
|
return router.Run(fmt.Sprintf(":%d", cfg.Server.HTTPPort))
|
||||||
|
}
|
||||||
|
|
||||||
|
func runMessageCleanup(messageRepo *postgres.MessagePostgresRepo) {
|
||||||
|
ticker := time.NewTicker(1 * time.Hour)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||||
|
|
||||||
|
// Delete messages older than 24 hours
|
||||||
|
cutoff := time.Now().Add(-24 * time.Hour)
|
||||||
|
count, err := messageRepo.DeleteOlderThan(ctx, cutoff)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to cleanup old messages", zap.Error(err))
|
||||||
|
} else if count > 0 {
|
||||||
|
logger.Info("Cleaned up old messages", zap.Int64("count", count))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,100 +1,100 @@
|
||||||
package entities
|
package entities
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MPCMessage represents an MPC protocol message
|
// MPCMessage represents an MPC protocol message
|
||||||
type MPCMessage struct {
|
type MPCMessage struct {
|
||||||
ID uuid.UUID
|
ID uuid.UUID
|
||||||
SessionID uuid.UUID
|
SessionID uuid.UUID
|
||||||
FromParty string
|
FromParty string
|
||||||
ToParties []string // nil means broadcast
|
ToParties []string // nil means broadcast
|
||||||
RoundNumber int
|
RoundNumber int
|
||||||
MessageType string
|
MessageType string
|
||||||
Payload []byte // Encrypted MPC message (router does not decrypt)
|
Payload []byte // Encrypted MPC message (router does not decrypt)
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
DeliveredAt *time.Time
|
DeliveredAt *time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMPCMessage creates a new MPC message
|
// NewMPCMessage creates a new MPC message
|
||||||
func NewMPCMessage(
|
func NewMPCMessage(
|
||||||
sessionID uuid.UUID,
|
sessionID uuid.UUID,
|
||||||
fromParty string,
|
fromParty string,
|
||||||
toParties []string,
|
toParties []string,
|
||||||
roundNumber int,
|
roundNumber int,
|
||||||
messageType string,
|
messageType string,
|
||||||
payload []byte,
|
payload []byte,
|
||||||
) *MPCMessage {
|
) *MPCMessage {
|
||||||
return &MPCMessage{
|
return &MPCMessage{
|
||||||
ID: uuid.New(),
|
ID: uuid.New(),
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
FromParty: fromParty,
|
FromParty: fromParty,
|
||||||
ToParties: toParties,
|
ToParties: toParties,
|
||||||
RoundNumber: roundNumber,
|
RoundNumber: roundNumber,
|
||||||
MessageType: messageType,
|
MessageType: messageType,
|
||||||
Payload: payload,
|
Payload: payload,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsBroadcast checks if the message is a broadcast
|
// IsBroadcast checks if the message is a broadcast
|
||||||
func (m *MPCMessage) IsBroadcast() bool {
|
func (m *MPCMessage) IsBroadcast() bool {
|
||||||
return len(m.ToParties) == 0
|
return len(m.ToParties) == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsFor checks if the message is for a specific party
|
// IsFor checks if the message is for a specific party
|
||||||
func (m *MPCMessage) IsFor(partyID string) bool {
|
func (m *MPCMessage) IsFor(partyID string) bool {
|
||||||
if m.IsBroadcast() {
|
if m.IsBroadcast() {
|
||||||
// Broadcast is for everyone except sender
|
// Broadcast is for everyone except sender
|
||||||
return m.FromParty != partyID
|
return m.FromParty != partyID
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, to := range m.ToParties {
|
for _, to := range m.ToParties {
|
||||||
if to == partyID {
|
if to == partyID {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkDelivered marks the message as delivered
|
// MarkDelivered marks the message as delivered
|
||||||
func (m *MPCMessage) MarkDelivered() {
|
func (m *MPCMessage) MarkDelivered() {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
m.DeliveredAt = &now
|
m.DeliveredAt = &now
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsDelivered checks if the message has been delivered
|
// IsDelivered checks if the message has been delivered
|
||||||
func (m *MPCMessage) IsDelivered() bool {
|
func (m *MPCMessage) IsDelivered() bool {
|
||||||
return m.DeliveredAt != nil
|
return m.DeliveredAt != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToDTO converts to DTO
|
// ToDTO converts to DTO
|
||||||
func (m *MPCMessage) ToDTO() MessageDTO {
|
func (m *MPCMessage) ToDTO() MessageDTO {
|
||||||
return MessageDTO{
|
return MessageDTO{
|
||||||
ID: m.ID.String(),
|
ID: m.ID.String(),
|
||||||
SessionID: m.SessionID.String(),
|
SessionID: m.SessionID.String(),
|
||||||
FromParty: m.FromParty,
|
FromParty: m.FromParty,
|
||||||
ToParties: m.ToParties,
|
ToParties: m.ToParties,
|
||||||
IsBroadcast: m.IsBroadcast(),
|
IsBroadcast: m.IsBroadcast(),
|
||||||
RoundNumber: m.RoundNumber,
|
RoundNumber: m.RoundNumber,
|
||||||
MessageType: m.MessageType,
|
MessageType: m.MessageType,
|
||||||
Payload: m.Payload,
|
Payload: m.Payload,
|
||||||
CreatedAt: m.CreatedAt.UnixMilli(),
|
CreatedAt: m.CreatedAt.UnixMilli(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MessageDTO is a data transfer object for messages
|
// MessageDTO is a data transfer object for messages
|
||||||
type MessageDTO struct {
|
type MessageDTO struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
SessionID string `json:"session_id"`
|
SessionID string `json:"session_id"`
|
||||||
FromParty string `json:"from_party"`
|
FromParty string `json:"from_party"`
|
||||||
ToParties []string `json:"to_parties,omitempty"`
|
ToParties []string `json:"to_parties,omitempty"`
|
||||||
IsBroadcast bool `json:"is_broadcast"`
|
IsBroadcast bool `json:"is_broadcast"`
|
||||||
RoundNumber int `json:"round_number"`
|
RoundNumber int `json:"round_number"`
|
||||||
MessageType string `json:"message_type"`
|
MessageType string `json:"message_type"`
|
||||||
Payload []byte `json:"payload"`
|
Payload []byte `json:"payload"`
|
||||||
CreatedAt int64 `json:"created_at"`
|
CreatedAt int64 `json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,93 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisteredParty represents a party registered with the router
|
||||||
|
type RegisteredParty struct {
|
||||||
|
PartyID string
|
||||||
|
Role string // persistent, delegate, temporary
|
||||||
|
Version string
|
||||||
|
RegisteredAt time.Time
|
||||||
|
LastSeen time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// PartyRegistry manages registered parties
|
||||||
|
type PartyRegistry struct {
|
||||||
|
parties map[string]*RegisteredParty
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPartyRegistry creates a new party registry
|
||||||
|
func NewPartyRegistry() *PartyRegistry {
|
||||||
|
return &PartyRegistry{
|
||||||
|
parties: make(map[string]*RegisteredParty),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register registers a party
|
||||||
|
func (r *PartyRegistry) Register(partyID, role, version string) *RegisteredParty {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
party := &RegisteredParty{
|
||||||
|
PartyID: partyID,
|
||||||
|
Role: role,
|
||||||
|
Version: version,
|
||||||
|
RegisteredAt: now,
|
||||||
|
LastSeen: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
r.parties[partyID] = party
|
||||||
|
return party
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a registered party
|
||||||
|
func (r *PartyRegistry) Get(partyID string) (*RegisteredParty, bool) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
party, exists := r.parties[partyID]
|
||||||
|
return party, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAll returns all registered parties
|
||||||
|
func (r *PartyRegistry) GetAll() []*RegisteredParty {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
parties := make([]*RegisteredParty, 0, len(r.parties))
|
||||||
|
for _, party := range r.parties {
|
||||||
|
parties = append(parties, party)
|
||||||
|
}
|
||||||
|
return parties
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLastSeen updates the last seen timestamp
|
||||||
|
func (r *PartyRegistry) UpdateLastSeen(partyID string) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if party, exists := r.parties[partyID]; exists {
|
||||||
|
party.LastSeen = time.Now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unregister removes a party from the registry
|
||||||
|
func (r *PartyRegistry) Unregister(partyID string) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
delete(r.parties, partyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the number of registered parties
|
||||||
|
func (r *PartyRegistry) Count() int {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
return len(r.parties)
|
||||||
|
}
|
||||||
|
|
@ -1,33 +1,33 @@
|
||||||
package repositories
|
package repositories
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
|
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MessageRepository defines the interface for message persistence
|
// MessageRepository defines the interface for message persistence
|
||||||
type MessageRepository interface {
|
type MessageRepository interface {
|
||||||
// Save persists a new message
|
// Save persists a new message
|
||||||
Save(ctx context.Context, msg *entities.MPCMessage) error
|
Save(ctx context.Context, msg *entities.MPCMessage) error
|
||||||
|
|
||||||
// GetByID retrieves a message by ID
|
// GetByID retrieves a message by ID
|
||||||
GetByID(ctx context.Context, id uuid.UUID) (*entities.MPCMessage, error)
|
GetByID(ctx context.Context, id uuid.UUID) (*entities.MPCMessage, error)
|
||||||
|
|
||||||
// GetPendingMessages retrieves pending messages for a party
|
// GetPendingMessages retrieves pending messages for a party
|
||||||
GetPendingMessages(ctx context.Context, sessionID uuid.UUID, partyID string, afterTime time.Time) ([]*entities.MPCMessage, error)
|
GetPendingMessages(ctx context.Context, sessionID uuid.UUID, partyID string, afterTime time.Time) ([]*entities.MPCMessage, error)
|
||||||
|
|
||||||
// GetMessagesByRound retrieves messages for a specific round
|
// GetMessagesByRound retrieves messages for a specific round
|
||||||
GetMessagesByRound(ctx context.Context, sessionID uuid.UUID, roundNumber int) ([]*entities.MPCMessage, error)
|
GetMessagesByRound(ctx context.Context, sessionID uuid.UUID, roundNumber int) ([]*entities.MPCMessage, error)
|
||||||
|
|
||||||
// MarkDelivered marks a message as delivered
|
// MarkDelivered marks a message as delivered
|
||||||
MarkDelivered(ctx context.Context, messageID uuid.UUID) error
|
MarkDelivered(ctx context.Context, messageID uuid.UUID) error
|
||||||
|
|
||||||
// DeleteBySession deletes all messages for a session
|
// DeleteBySession deletes all messages for a session
|
||||||
DeleteBySession(ctx context.Context, sessionID uuid.UUID) error
|
DeleteBySession(ctx context.Context, sessionID uuid.UUID) error
|
||||||
|
|
||||||
// DeleteOlderThan deletes messages older than a specific time
|
// DeleteOlderThan deletes messages older than a specific time
|
||||||
DeleteOlderThan(ctx context.Context, before time.Time) (int64, error)
|
DeleteOlderThan(ctx context.Context, before time.Time) (int64, error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,83 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
pb "github.com/rwadurian/mpc-system/api/grpc/router/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionEventBroadcaster manages session event subscriptions and broadcasting
|
||||||
|
type SessionEventBroadcaster struct {
|
||||||
|
subscribers map[string]chan *pb.SessionEvent // partyID -> event channel
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionEventBroadcaster creates a new session event broadcaster
|
||||||
|
func NewSessionEventBroadcaster() *SessionEventBroadcaster {
|
||||||
|
return &SessionEventBroadcaster{
|
||||||
|
subscribers: make(map[string]chan *pb.SessionEvent),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe subscribes a party to session events
|
||||||
|
func (b *SessionEventBroadcaster) Subscribe(partyID string) <-chan *pb.SessionEvent {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
// Create buffered channel for this subscriber
|
||||||
|
ch := make(chan *pb.SessionEvent, 100)
|
||||||
|
b.subscribers[partyID] = ch
|
||||||
|
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe removes a party's subscription
|
||||||
|
func (b *SessionEventBroadcaster) Unsubscribe(partyID string) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
if ch, exists := b.subscribers[partyID]; exists {
|
||||||
|
close(ch)
|
||||||
|
delete(b.subscribers, partyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast sends an event to all subscribers
|
||||||
|
func (b *SessionEventBroadcaster) Broadcast(event *pb.SessionEvent) {
|
||||||
|
b.mu.RLock()
|
||||||
|
defer b.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, ch := range b.subscribers {
|
||||||
|
// Non-blocking send to prevent slow subscribers from blocking
|
||||||
|
select {
|
||||||
|
case ch <- event:
|
||||||
|
default:
|
||||||
|
// Channel full, skip this subscriber
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BroadcastToParties sends an event to specific parties only
|
||||||
|
func (b *SessionEventBroadcaster) BroadcastToParties(event *pb.SessionEvent, partyIDs []string) {
|
||||||
|
b.mu.RLock()
|
||||||
|
defer b.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, partyID := range partyIDs {
|
||||||
|
if ch, exists := b.subscribers[partyID]; exists {
|
||||||
|
// Non-blocking send
|
||||||
|
select {
|
||||||
|
case ch <- event:
|
||||||
|
default:
|
||||||
|
// Channel full, skip this subscriber
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubscriberCount returns the number of active subscribers
|
||||||
|
func (b *SessionEventBroadcaster) SubscriberCount() int {
|
||||||
|
b.mu.RLock()
|
||||||
|
defer b.mu.RUnlock()
|
||||||
|
|
||||||
|
return len(b.subscribers)
|
||||||
|
}
|
||||||
|
|
@ -1,38 +1,38 @@
|
||||||
# Build stage
|
# Build stage
|
||||||
FROM golang:1.21-alpine AS builder
|
FROM golang:1.21-alpine AS builder
|
||||||
|
|
||||||
RUN apk add --no-cache git ca-certificates
|
RUN apk add --no-cache git ca-certificates
|
||||||
|
|
||||||
# Set Go proxy (can be overridden with --build-arg GOPROXY=...)
|
# Set Go proxy (can be overridden with --build-arg GOPROXY=...)
|
||||||
ARG GOPROXY=https://proxy.golang.org,direct
|
ARG GOPROXY=https://proxy.golang.org,direct
|
||||||
ENV GOPROXY=${GOPROXY}
|
ENV GOPROXY=${GOPROXY}
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
COPY go.mod go.sum ./
|
COPY go.mod go.sum ./
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
||||||
-ldflags="-w -s" \
|
-ldflags="-w -s" \
|
||||||
-o /bin/server-party-api \
|
-o /bin/server-party-api \
|
||||||
./services/server-party-api/cmd/server
|
./services/server-party-api/cmd/server
|
||||||
|
|
||||||
# Final stage
|
# Final stage
|
||||||
FROM alpine:3.18
|
FROM alpine:3.18
|
||||||
|
|
||||||
RUN apk --no-cache add ca-certificates curl
|
RUN apk --no-cache add ca-certificates curl
|
||||||
RUN adduser -D -s /bin/sh mpc
|
RUN adduser -D -s /bin/sh mpc
|
||||||
|
|
||||||
COPY --from=builder /bin/server-party-api /bin/server-party-api
|
COPY --from=builder /bin/server-party-api /bin/server-party-api
|
||||||
|
|
||||||
USER mpc
|
USER mpc
|
||||||
|
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|
||||||
# Health check
|
# Health check
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||||
CMD curl -sf http://localhost:8080/health || exit 1
|
CMD curl -sf http://localhost:8080/health || exit 1
|
||||||
|
|
||||||
ENTRYPOINT ["/bin/server-party-api"]
|
ENTRYPOINT ["/bin/server-party-api"]
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,38 +1,38 @@
|
||||||
# Build stage
|
# Build stage
|
||||||
FROM golang:1.21-alpine AS builder
|
FROM golang:1.21-alpine AS builder
|
||||||
|
|
||||||
RUN apk add --no-cache git ca-certificates
|
RUN apk add --no-cache git ca-certificates
|
||||||
|
|
||||||
# Set Go proxy (can be overridden with --build-arg GOPROXY=...)
|
# Set Go proxy (can be overridden with --build-arg GOPROXY=...)
|
||||||
ARG GOPROXY=https://proxy.golang.org,direct
|
ARG GOPROXY=https://proxy.golang.org,direct
|
||||||
ENV GOPROXY=${GOPROXY}
|
ENV GOPROXY=${GOPROXY}
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
COPY go.mod go.sum ./
|
COPY go.mod go.sum ./
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
||||||
-ldflags="-w -s" \
|
-ldflags="-w -s" \
|
||||||
-o /bin/server-party \
|
-o /bin/server-party \
|
||||||
./services/server-party/cmd/server
|
./services/server-party/cmd/server
|
||||||
|
|
||||||
# Final stage
|
# Final stage
|
||||||
FROM alpine:3.18
|
FROM alpine:3.18
|
||||||
|
|
||||||
RUN apk --no-cache add ca-certificates curl
|
RUN apk --no-cache add ca-certificates curl
|
||||||
RUN adduser -D -s /bin/sh mpc
|
RUN adduser -D -s /bin/sh mpc
|
||||||
|
|
||||||
COPY --from=builder /bin/server-party /bin/server-party
|
COPY --from=builder /bin/server-party /bin/server-party
|
||||||
|
|
||||||
USER mpc
|
USER mpc
|
||||||
|
|
||||||
EXPOSE 50051 8080
|
EXPOSE 50051 8080
|
||||||
|
|
||||||
# Health check
|
# Health check
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||||
CMD curl -sf http://localhost:8080/health || exit 1
|
CMD curl -sf http://localhost:8080/health || exit 1
|
||||||
|
|
||||||
ENTRYPOINT ["/bin/server-party"]
|
ENTRYPOINT ["/bin/server-party"]
|
||||||
|
|
|
||||||
|
|
@ -1,229 +1,229 @@
|
||||||
package grpc
|
package grpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
router "github.com/rwadurian/mpc-system/api/grpc/router/v1"
|
router "github.com/rwadurian/mpc-system/api/grpc/router/v1"
|
||||||
"github.com/rwadurian/mpc-system/pkg/logger"
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
||||||
"github.com/rwadurian/mpc-system/services/server-party/application/use_cases"
|
"github.com/rwadurian/mpc-system/services/server-party/application/use_cases"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MessageRouterClient implements use_cases.MessageRouterClient
|
// MessageRouterClient implements use_cases.MessageRouterClient
|
||||||
type MessageRouterClient struct {
|
type MessageRouterClient struct {
|
||||||
conn *grpc.ClientConn
|
conn *grpc.ClientConn
|
||||||
address string
|
address string
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMessageRouterClient creates a new message router gRPC client
|
// NewMessageRouterClient creates a new message router gRPC client
|
||||||
func NewMessageRouterClient(address string) (*MessageRouterClient, error) {
|
func NewMessageRouterClient(address string) (*MessageRouterClient, error) {
|
||||||
conn, err := grpc.Dial(
|
conn, err := grpc.Dial(
|
||||||
address,
|
address,
|
||||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
grpc.WithBlock(),
|
grpc.WithBlock(),
|
||||||
grpc.WithTimeout(10*time.Second),
|
grpc.WithTimeout(10*time.Second),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Connected to Message Router", zap.String("address", address))
|
logger.Info("Connected to Message Router", zap.String("address", address))
|
||||||
|
|
||||||
return &MessageRouterClient{
|
return &MessageRouterClient{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
address: address,
|
address: address,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the gRPC connection
|
// Close closes the gRPC connection
|
||||||
func (c *MessageRouterClient) Close() error {
|
func (c *MessageRouterClient) Close() error {
|
||||||
if c.conn != nil {
|
if c.conn != nil {
|
||||||
return c.conn.Close()
|
return c.conn.Close()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouteMessage sends an MPC protocol message to other parties
|
// RouteMessage sends an MPC protocol message to other parties
|
||||||
func (c *MessageRouterClient) RouteMessage(
|
func (c *MessageRouterClient) RouteMessage(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID uuid.UUID,
|
sessionID uuid.UUID,
|
||||||
fromParty string,
|
fromParty string,
|
||||||
toParties []string,
|
toParties []string,
|
||||||
roundNumber int,
|
roundNumber int,
|
||||||
payload []byte,
|
payload []byte,
|
||||||
) error {
|
) error {
|
||||||
req := &router.RouteMessageRequest{
|
req := &router.RouteMessageRequest{
|
||||||
SessionId: sessionID.String(),
|
SessionId: sessionID.String(),
|
||||||
FromParty: fromParty,
|
FromParty: fromParty,
|
||||||
ToParties: toParties,
|
ToParties: toParties,
|
||||||
RoundNumber: int32(roundNumber),
|
RoundNumber: int32(roundNumber),
|
||||||
MessageType: "tss",
|
MessageType: "tss",
|
||||||
Payload: payload,
|
Payload: payload,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := &router.RouteMessageResponse{}
|
resp := &router.RouteMessageResponse{}
|
||||||
err := c.conn.Invoke(ctx, "/mpc.router.v1.MessageRouter/RouteMessage", req, resp)
|
err := c.conn.Invoke(ctx, "/mpc.router.v1.MessageRouter/RouteMessage", req, resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to route message",
|
logger.Error("Failed to route message",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
zap.String("session_id", sessionID.String()),
|
zap.String("session_id", sessionID.String()),
|
||||||
zap.String("from", fromParty))
|
zap.String("from", fromParty))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !resp.Success {
|
if !resp.Success {
|
||||||
logger.Error("Message routing failed",
|
logger.Error("Message routing failed",
|
||||||
zap.String("session_id", sessionID.String()))
|
zap.String("session_id", sessionID.String()))
|
||||||
return use_cases.ErrKeygenFailed
|
return use_cases.ErrKeygenFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Message routed successfully",
|
logger.Debug("Message routed successfully",
|
||||||
zap.String("session_id", sessionID.String()),
|
zap.String("session_id", sessionID.String()),
|
||||||
zap.String("from", fromParty),
|
zap.String("from", fromParty),
|
||||||
zap.Int("to_count", len(toParties)),
|
zap.Int("to_count", len(toParties)),
|
||||||
zap.Int("round", roundNumber))
|
zap.Int("round", roundNumber))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubscribeMessages subscribes to MPC messages for a party
|
// SubscribeMessages subscribes to MPC messages for a party
|
||||||
func (c *MessageRouterClient) SubscribeMessages(
|
func (c *MessageRouterClient) SubscribeMessages(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID uuid.UUID,
|
sessionID uuid.UUID,
|
||||||
partyID string,
|
partyID string,
|
||||||
) (<-chan *use_cases.MPCMessage, error) {
|
) (<-chan *use_cases.MPCMessage, error) {
|
||||||
req := &router.SubscribeMessagesRequest{
|
req := &router.SubscribeMessagesRequest{
|
||||||
SessionId: sessionID.String(),
|
SessionId: sessionID.String(),
|
||||||
PartyId: partyID,
|
PartyId: partyID,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a streaming connection
|
// Create a streaming connection
|
||||||
stream, err := c.createSubscribeStream(ctx, req)
|
stream, err := c.createSubscribeStream(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to subscribe to messages",
|
logger.Error("Failed to subscribe to messages",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
zap.String("session_id", sessionID.String()),
|
zap.String("session_id", sessionID.String()),
|
||||||
zap.String("party_id", partyID))
|
zap.String("party_id", partyID))
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create output channel
|
// Create output channel
|
||||||
msgChan := make(chan *use_cases.MPCMessage, 100)
|
msgChan := make(chan *use_cases.MPCMessage, 100)
|
||||||
|
|
||||||
// Start goroutine to receive messages
|
// Start goroutine to receive messages
|
||||||
go func() {
|
go func() {
|
||||||
defer close(msgChan)
|
defer close(msgChan)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
logger.Debug("Message subscription context cancelled",
|
logger.Debug("Message subscription context cancelled",
|
||||||
zap.String("session_id", sessionID.String()),
|
zap.String("session_id", sessionID.String()),
|
||||||
zap.String("party_id", partyID))
|
zap.String("party_id", partyID))
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
msg := &router.MPCMessage{}
|
msg := &router.MPCMessage{}
|
||||||
err := stream.RecvMsg(msg)
|
err := stream.RecvMsg(msg)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
logger.Debug("Message stream ended",
|
logger.Debug("Message stream ended",
|
||||||
zap.String("session_id", sessionID.String()))
|
zap.String("session_id", sessionID.String()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error receiving message",
|
logger.Error("Error receiving message",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
zap.String("session_id", sessionID.String()))
|
zap.String("session_id", sessionID.String()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to use_cases.MPCMessage
|
// Convert to use_cases.MPCMessage
|
||||||
mpcMsg := &use_cases.MPCMessage{
|
mpcMsg := &use_cases.MPCMessage{
|
||||||
FromParty: msg.FromParty,
|
FromParty: msg.FromParty,
|
||||||
IsBroadcast: msg.IsBroadcast,
|
IsBroadcast: msg.IsBroadcast,
|
||||||
RoundNumber: int(msg.RoundNumber),
|
RoundNumber: int(msg.RoundNumber),
|
||||||
Payload: msg.Payload,
|
Payload: msg.Payload,
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case msgChan <- mpcMsg:
|
case msgChan <- mpcMsg:
|
||||||
logger.Debug("Received MPC message",
|
logger.Debug("Received MPC message",
|
||||||
zap.String("from", msg.FromParty),
|
zap.String("from", msg.FromParty),
|
||||||
zap.Int("round", int(msg.RoundNumber)))
|
zap.Int("round", int(msg.RoundNumber)))
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
logger.Info("Subscribed to messages",
|
logger.Info("Subscribed to messages",
|
||||||
zap.String("session_id", sessionID.String()),
|
zap.String("session_id", sessionID.String()),
|
||||||
zap.String("party_id", partyID))
|
zap.String("party_id", partyID))
|
||||||
|
|
||||||
return msgChan, nil
|
return msgChan, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createSubscribeStream creates a streaming connection for message subscription
|
// createSubscribeStream creates a streaming connection for message subscription
|
||||||
func (c *MessageRouterClient) createSubscribeStream(
|
func (c *MessageRouterClient) createSubscribeStream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *router.SubscribeMessagesRequest,
|
req *router.SubscribeMessagesRequest,
|
||||||
) (grpc.ClientStream, error) {
|
) (grpc.ClientStream, error) {
|
||||||
streamDesc := &grpc.StreamDesc{
|
streamDesc := &grpc.StreamDesc{
|
||||||
StreamName: "SubscribeMessages",
|
StreamName: "SubscribeMessages",
|
||||||
ServerStreams: true,
|
ServerStreams: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
stream, err := c.conn.NewStream(ctx, streamDesc, "/mpc.router.v1.MessageRouter/SubscribeMessages")
|
stream, err := c.conn.NewStream(ctx, streamDesc, "/mpc.router.v1.MessageRouter/SubscribeMessages")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stream.SendMsg(req); err != nil {
|
if err := stream.SendMsg(req); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stream.CloseSend(); err != nil {
|
if err := stream.CloseSend(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return stream, nil
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPendingMessages gets pending messages (polling alternative)
|
// GetPendingMessages gets pending messages (polling alternative)
|
||||||
func (c *MessageRouterClient) GetPendingMessages(
|
func (c *MessageRouterClient) GetPendingMessages(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID uuid.UUID,
|
sessionID uuid.UUID,
|
||||||
partyID string,
|
partyID string,
|
||||||
afterTimestamp int64,
|
afterTimestamp int64,
|
||||||
) ([]*use_cases.MPCMessage, error) {
|
) ([]*use_cases.MPCMessage, error) {
|
||||||
req := &router.GetPendingMessagesRequest{
|
req := &router.GetPendingMessagesRequest{
|
||||||
SessionId: sessionID.String(),
|
SessionId: sessionID.String(),
|
||||||
PartyId: partyID,
|
PartyId: partyID,
|
||||||
AfterTimestamp: afterTimestamp,
|
AfterTimestamp: afterTimestamp,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := &router.GetPendingMessagesResponse{}
|
resp := &router.GetPendingMessagesResponse{}
|
||||||
err := c.conn.Invoke(ctx, "/mpc.router.v1.MessageRouter/GetPendingMessages", req, resp)
|
err := c.conn.Invoke(ctx, "/mpc.router.v1.MessageRouter/GetPendingMessages", req, resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
messages := make([]*use_cases.MPCMessage, len(resp.Messages))
|
messages := make([]*use_cases.MPCMessage, len(resp.Messages))
|
||||||
for i, msg := range resp.Messages {
|
for i, msg := range resp.Messages {
|
||||||
messages[i] = &use_cases.MPCMessage{
|
messages[i] = &use_cases.MPCMessage{
|
||||||
FromParty: msg.FromParty,
|
FromParty: msg.FromParty,
|
||||||
IsBroadcast: msg.IsBroadcast,
|
IsBroadcast: msg.IsBroadcast,
|
||||||
RoundNumber: int(msg.RoundNumber),
|
RoundNumber: int(msg.RoundNumber),
|
||||||
Payload: msg.Payload,
|
Payload: msg.Payload,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,170 +1,170 @@
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
|
"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
|
"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
|
||||||
)
|
)
|
||||||
|
|
||||||
// KeySharePostgresRepo implements KeyShareRepository for PostgreSQL
|
// KeySharePostgresRepo implements KeyShareRepository for PostgreSQL
|
||||||
type KeySharePostgresRepo struct {
|
type KeySharePostgresRepo struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewKeySharePostgresRepo creates a new PostgreSQL key share repository
|
// NewKeySharePostgresRepo creates a new PostgreSQL key share repository
|
||||||
func NewKeySharePostgresRepo(db *sql.DB) *KeySharePostgresRepo {
|
func NewKeySharePostgresRepo(db *sql.DB) *KeySharePostgresRepo {
|
||||||
return &KeySharePostgresRepo{db: db}
|
return &KeySharePostgresRepo{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save persists a new key share
|
// Save persists a new key share
|
||||||
func (r *KeySharePostgresRepo) Save(ctx context.Context, keyShare *entities.PartyKeyShare) error {
|
func (r *KeySharePostgresRepo) Save(ctx context.Context, keyShare *entities.PartyKeyShare) error {
|
||||||
_, err := r.db.ExecContext(ctx, `
|
_, err := r.db.ExecContext(ctx, `
|
||||||
INSERT INTO party_key_shares (
|
INSERT INTO party_key_shares (
|
||||||
id, party_id, party_index, session_id, threshold_n, threshold_t,
|
id, party_id, party_index, session_id, threshold_n, threshold_t,
|
||||||
share_data, public_key, created_at
|
share_data, public_key, created_at
|
||||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||||
`,
|
`,
|
||||||
keyShare.ID,
|
keyShare.ID,
|
||||||
keyShare.PartyID,
|
keyShare.PartyID,
|
||||||
keyShare.PartyIndex,
|
keyShare.PartyIndex,
|
||||||
keyShare.SessionID,
|
keyShare.SessionID,
|
||||||
keyShare.ThresholdN,
|
keyShare.ThresholdN,
|
||||||
keyShare.ThresholdT,
|
keyShare.ThresholdT,
|
||||||
keyShare.ShareData,
|
keyShare.ShareData,
|
||||||
keyShare.PublicKey,
|
keyShare.PublicKey,
|
||||||
keyShare.CreatedAt,
|
keyShare.CreatedAt,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindByID retrieves a key share by ID
|
// FindByID retrieves a key share by ID
|
||||||
func (r *KeySharePostgresRepo) FindByID(ctx context.Context, id uuid.UUID) (*entities.PartyKeyShare, error) {
|
func (r *KeySharePostgresRepo) FindByID(ctx context.Context, id uuid.UUID) (*entities.PartyKeyShare, error) {
|
||||||
var ks entities.PartyKeyShare
|
var ks entities.PartyKeyShare
|
||||||
err := r.db.QueryRowContext(ctx, `
|
err := r.db.QueryRowContext(ctx, `
|
||||||
SELECT id, party_id, party_index, session_id, threshold_n, threshold_t,
|
SELECT id, party_id, party_index, session_id, threshold_n, threshold_t,
|
||||||
share_data, public_key, created_at, last_used_at
|
share_data, public_key, created_at, last_used_at
|
||||||
FROM party_key_shares WHERE id = $1
|
FROM party_key_shares WHERE id = $1
|
||||||
`, id).Scan(
|
`, id).Scan(
|
||||||
&ks.ID,
|
&ks.ID,
|
||||||
&ks.PartyID,
|
&ks.PartyID,
|
||||||
&ks.PartyIndex,
|
&ks.PartyIndex,
|
||||||
&ks.SessionID,
|
&ks.SessionID,
|
||||||
&ks.ThresholdN,
|
&ks.ThresholdN,
|
||||||
&ks.ThresholdT,
|
&ks.ThresholdT,
|
||||||
&ks.ShareData,
|
&ks.ShareData,
|
||||||
&ks.PublicKey,
|
&ks.PublicKey,
|
||||||
&ks.CreatedAt,
|
&ks.CreatedAt,
|
||||||
&ks.LastUsedAt,
|
&ks.LastUsedAt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &ks, nil
|
return &ks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindBySessionAndParty retrieves a key share by session and party
|
// FindBySessionAndParty retrieves a key share by session and party
|
||||||
func (r *KeySharePostgresRepo) FindBySessionAndParty(ctx context.Context, sessionID uuid.UUID, partyID string) (*entities.PartyKeyShare, error) {
|
func (r *KeySharePostgresRepo) FindBySessionAndParty(ctx context.Context, sessionID uuid.UUID, partyID string) (*entities.PartyKeyShare, error) {
|
||||||
var ks entities.PartyKeyShare
|
var ks entities.PartyKeyShare
|
||||||
err := r.db.QueryRowContext(ctx, `
|
err := r.db.QueryRowContext(ctx, `
|
||||||
SELECT id, party_id, party_index, session_id, threshold_n, threshold_t,
|
SELECT id, party_id, party_index, session_id, threshold_n, threshold_t,
|
||||||
share_data, public_key, created_at, last_used_at
|
share_data, public_key, created_at, last_used_at
|
||||||
FROM party_key_shares WHERE session_id = $1 AND party_id = $2
|
FROM party_key_shares WHERE session_id = $1 AND party_id = $2
|
||||||
`, sessionID, partyID).Scan(
|
`, sessionID, partyID).Scan(
|
||||||
&ks.ID,
|
&ks.ID,
|
||||||
&ks.PartyID,
|
&ks.PartyID,
|
||||||
&ks.PartyIndex,
|
&ks.PartyIndex,
|
||||||
&ks.SessionID,
|
&ks.SessionID,
|
||||||
&ks.ThresholdN,
|
&ks.ThresholdN,
|
||||||
&ks.ThresholdT,
|
&ks.ThresholdT,
|
||||||
&ks.ShareData,
|
&ks.ShareData,
|
||||||
&ks.PublicKey,
|
&ks.PublicKey,
|
||||||
&ks.CreatedAt,
|
&ks.CreatedAt,
|
||||||
&ks.LastUsedAt,
|
&ks.LastUsedAt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &ks, nil
|
return &ks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindByPublicKey retrieves key shares by public key
|
// FindByPublicKey retrieves key shares by public key
|
||||||
func (r *KeySharePostgresRepo) FindByPublicKey(ctx context.Context, publicKey []byte) ([]*entities.PartyKeyShare, error) {
|
func (r *KeySharePostgresRepo) FindByPublicKey(ctx context.Context, publicKey []byte) ([]*entities.PartyKeyShare, error) {
|
||||||
rows, err := r.db.QueryContext(ctx, `
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
SELECT id, party_id, party_index, session_id, threshold_n, threshold_t,
|
SELECT id, party_id, party_index, session_id, threshold_n, threshold_t,
|
||||||
share_data, public_key, created_at, last_used_at
|
share_data, public_key, created_at, last_used_at
|
||||||
FROM party_key_shares WHERE public_key = $1
|
FROM party_key_shares WHERE public_key = $1
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
`, publicKey)
|
`, publicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
return r.scanKeyShares(rows)
|
return r.scanKeyShares(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update updates an existing key share
|
// Update updates an existing key share
|
||||||
func (r *KeySharePostgresRepo) Update(ctx context.Context, keyShare *entities.PartyKeyShare) error {
|
func (r *KeySharePostgresRepo) Update(ctx context.Context, keyShare *entities.PartyKeyShare) error {
|
||||||
_, err := r.db.ExecContext(ctx, `
|
_, err := r.db.ExecContext(ctx, `
|
||||||
UPDATE party_key_shares SET last_used_at = $1 WHERE id = $2
|
UPDATE party_key_shares SET last_used_at = $1 WHERE id = $2
|
||||||
`, keyShare.LastUsedAt, keyShare.ID)
|
`, keyShare.LastUsedAt, keyShare.ID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes a key share
|
// Delete removes a key share
|
||||||
func (r *KeySharePostgresRepo) Delete(ctx context.Context, id uuid.UUID) error {
|
func (r *KeySharePostgresRepo) Delete(ctx context.Context, id uuid.UUID) error {
|
||||||
_, err := r.db.ExecContext(ctx, `DELETE FROM party_key_shares WHERE id = $1`, id)
|
_, err := r.db.ExecContext(ctx, `DELETE FROM party_key_shares WHERE id = $1`, id)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListByParty lists all key shares for a party
|
// ListByParty lists all key shares for a party
|
||||||
func (r *KeySharePostgresRepo) ListByParty(ctx context.Context, partyID string) ([]*entities.PartyKeyShare, error) {
|
func (r *KeySharePostgresRepo) ListByParty(ctx context.Context, partyID string) ([]*entities.PartyKeyShare, error) {
|
||||||
rows, err := r.db.QueryContext(ctx, `
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
SELECT id, party_id, party_index, session_id, threshold_n, threshold_t,
|
SELECT id, party_id, party_index, session_id, threshold_n, threshold_t,
|
||||||
share_data, public_key, created_at, last_used_at
|
share_data, public_key, created_at, last_used_at
|
||||||
FROM party_key_shares WHERE party_id = $1
|
FROM party_key_shares WHERE party_id = $1
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
`, partyID)
|
`, partyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
return r.scanKeyShares(rows)
|
return r.scanKeyShares(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *KeySharePostgresRepo) scanKeyShares(rows *sql.Rows) ([]*entities.PartyKeyShare, error) {
|
func (r *KeySharePostgresRepo) scanKeyShares(rows *sql.Rows) ([]*entities.PartyKeyShare, error) {
|
||||||
var keyShares []*entities.PartyKeyShare
|
var keyShares []*entities.PartyKeyShare
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var ks entities.PartyKeyShare
|
var ks entities.PartyKeyShare
|
||||||
err := rows.Scan(
|
err := rows.Scan(
|
||||||
&ks.ID,
|
&ks.ID,
|
||||||
&ks.PartyID,
|
&ks.PartyID,
|
||||||
&ks.PartyIndex,
|
&ks.PartyIndex,
|
||||||
&ks.SessionID,
|
&ks.SessionID,
|
||||||
&ks.ThresholdN,
|
&ks.ThresholdN,
|
||||||
&ks.ThresholdT,
|
&ks.ThresholdT,
|
||||||
&ks.ShareData,
|
&ks.ShareData,
|
||||||
&ks.PublicKey,
|
&ks.PublicKey,
|
||||||
&ks.CreatedAt,
|
&ks.CreatedAt,
|
||||||
&ks.LastUsedAt,
|
&ks.LastUsedAt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
keyShares = append(keyShares, &ks)
|
keyShares = append(keyShares, &ks)
|
||||||
}
|
}
|
||||||
return keyShares, rows.Err()
|
return keyShares, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure interface compliance
|
// Ensure interface compliance
|
||||||
var _ repositories.KeyShareRepository = (*KeySharePostgresRepo)(nil)
|
var _ repositories.KeyShareRepository = (*KeySharePostgresRepo)(nil)
|
||||||
|
|
|
||||||
|
|
@ -1,294 +1,294 @@
|
||||||
package use_cases
|
package use_cases
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/pkg/crypto"
|
"github.com/rwadurian/mpc-system/pkg/crypto"
|
||||||
"github.com/rwadurian/mpc-system/pkg/logger"
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
||||||
"github.com/rwadurian/mpc-system/pkg/tss"
|
"github.com/rwadurian/mpc-system/pkg/tss"
|
||||||
"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
|
"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
|
"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrKeygenFailed = errors.New("keygen failed")
|
ErrKeygenFailed = errors.New("keygen failed")
|
||||||
ErrKeygenTimeout = errors.New("keygen timeout")
|
ErrKeygenTimeout = errors.New("keygen timeout")
|
||||||
ErrInvalidSession = errors.New("invalid session")
|
ErrInvalidSession = errors.New("invalid session")
|
||||||
ErrShareSaveFailed = errors.New("failed to save share")
|
ErrShareSaveFailed = errors.New("failed to save share")
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParticipateKeygenInput contains input for keygen participation
|
// ParticipateKeygenInput contains input for keygen participation
|
||||||
type ParticipateKeygenInput struct {
|
type ParticipateKeygenInput struct {
|
||||||
SessionID uuid.UUID
|
SessionID uuid.UUID
|
||||||
PartyID string
|
PartyID string
|
||||||
JoinToken string
|
JoinToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParticipateKeygenOutput contains output from keygen participation
|
// ParticipateKeygenOutput contains output from keygen participation
|
||||||
type ParticipateKeygenOutput struct {
|
type ParticipateKeygenOutput struct {
|
||||||
Success bool
|
Success bool
|
||||||
KeyShare *entities.PartyKeyShare
|
KeyShare *entities.PartyKeyShare
|
||||||
PublicKey []byte
|
PublicKey []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionCoordinatorClient defines the interface for session coordinator communication
|
// SessionCoordinatorClient defines the interface for session coordinator communication
|
||||||
type SessionCoordinatorClient interface {
|
type SessionCoordinatorClient interface {
|
||||||
JoinSession(ctx context.Context, sessionID uuid.UUID, partyID, joinToken string) (*SessionInfo, error)
|
JoinSession(ctx context.Context, sessionID uuid.UUID, partyID, joinToken string) (*SessionInfo, error)
|
||||||
ReportCompletion(ctx context.Context, sessionID uuid.UUID, partyID string, publicKey []byte) error
|
ReportCompletion(ctx context.Context, sessionID uuid.UUID, partyID string, publicKey []byte) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// MessageRouterClient defines the interface for message router communication
|
// MessageRouterClient defines the interface for message router communication
|
||||||
type MessageRouterClient interface {
|
type MessageRouterClient interface {
|
||||||
RouteMessage(ctx context.Context, sessionID uuid.UUID, fromParty string, toParties []string, roundNumber int, payload []byte) error
|
RouteMessage(ctx context.Context, sessionID uuid.UUID, fromParty string, toParties []string, roundNumber int, payload []byte) error
|
||||||
SubscribeMessages(ctx context.Context, sessionID uuid.UUID, partyID string) (<-chan *MPCMessage, error)
|
SubscribeMessages(ctx context.Context, sessionID uuid.UUID, partyID string) (<-chan *MPCMessage, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionInfo contains session information from coordinator
|
// SessionInfo contains session information from coordinator
|
||||||
type SessionInfo struct {
|
type SessionInfo struct {
|
||||||
SessionID uuid.UUID
|
SessionID uuid.UUID
|
||||||
SessionType string
|
SessionType string
|
||||||
ThresholdN int
|
ThresholdN int
|
||||||
ThresholdT int
|
ThresholdT int
|
||||||
MessageHash []byte
|
MessageHash []byte
|
||||||
Participants []ParticipantInfo
|
Participants []ParticipantInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParticipantInfo contains participant information
|
// ParticipantInfo contains participant information
|
||||||
type ParticipantInfo struct {
|
type ParticipantInfo struct {
|
||||||
PartyID string
|
PartyID string
|
||||||
PartyIndex int
|
PartyIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
// MPCMessage represents an MPC message from the router
|
// MPCMessage represents an MPC message from the router
|
||||||
type MPCMessage struct {
|
type MPCMessage struct {
|
||||||
FromParty string
|
FromParty string
|
||||||
IsBroadcast bool
|
IsBroadcast bool
|
||||||
RoundNumber int
|
RoundNumber int
|
||||||
Payload []byte
|
Payload []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParticipateKeygenUseCase handles keygen participation
|
// ParticipateKeygenUseCase handles keygen participation
|
||||||
type ParticipateKeygenUseCase struct {
|
type ParticipateKeygenUseCase struct {
|
||||||
keyShareRepo repositories.KeyShareRepository
|
keyShareRepo repositories.KeyShareRepository
|
||||||
sessionClient SessionCoordinatorClient
|
sessionClient SessionCoordinatorClient
|
||||||
messageRouter MessageRouterClient
|
messageRouter MessageRouterClient
|
||||||
cryptoService *crypto.CryptoService
|
cryptoService *crypto.CryptoService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewParticipateKeygenUseCase creates a new participate keygen use case
|
// NewParticipateKeygenUseCase creates a new participate keygen use case
|
||||||
func NewParticipateKeygenUseCase(
|
func NewParticipateKeygenUseCase(
|
||||||
keyShareRepo repositories.KeyShareRepository,
|
keyShareRepo repositories.KeyShareRepository,
|
||||||
sessionClient SessionCoordinatorClient,
|
sessionClient SessionCoordinatorClient,
|
||||||
messageRouter MessageRouterClient,
|
messageRouter MessageRouterClient,
|
||||||
cryptoService *crypto.CryptoService,
|
cryptoService *crypto.CryptoService,
|
||||||
) *ParticipateKeygenUseCase {
|
) *ParticipateKeygenUseCase {
|
||||||
return &ParticipateKeygenUseCase{
|
return &ParticipateKeygenUseCase{
|
||||||
keyShareRepo: keyShareRepo,
|
keyShareRepo: keyShareRepo,
|
||||||
sessionClient: sessionClient,
|
sessionClient: sessionClient,
|
||||||
messageRouter: messageRouter,
|
messageRouter: messageRouter,
|
||||||
cryptoService: cryptoService,
|
cryptoService: cryptoService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute participates in a keygen session using real TSS protocol
|
// Execute participates in a keygen session using real TSS protocol
|
||||||
func (uc *ParticipateKeygenUseCase) Execute(
|
func (uc *ParticipateKeygenUseCase) Execute(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
input ParticipateKeygenInput,
|
input ParticipateKeygenInput,
|
||||||
) (*ParticipateKeygenOutput, error) {
|
) (*ParticipateKeygenOutput, error) {
|
||||||
// 1. Join session via coordinator
|
// 1. Join session via coordinator
|
||||||
sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
|
sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if sessionInfo.SessionType != "keygen" {
|
if sessionInfo.SessionType != "keygen" {
|
||||||
return nil, ErrInvalidSession
|
return nil, ErrInvalidSession
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Find self in participants and build party index map
|
// 2. Find self in participants and build party index map
|
||||||
var selfIndex int
|
var selfIndex int
|
||||||
partyIndexMap := make(map[string]int)
|
partyIndexMap := make(map[string]int)
|
||||||
for _, p := range sessionInfo.Participants {
|
for _, p := range sessionInfo.Participants {
|
||||||
partyIndexMap[p.PartyID] = p.PartyIndex
|
partyIndexMap[p.PartyID] = p.PartyIndex
|
||||||
if p.PartyID == input.PartyID {
|
if p.PartyID == input.PartyID {
|
||||||
selfIndex = p.PartyIndex
|
selfIndex = p.PartyIndex
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Subscribe to messages
|
// 3. Subscribe to messages
|
||||||
msgChan, err := uc.messageRouter.SubscribeMessages(ctx, input.SessionID, input.PartyID)
|
msgChan, err := uc.messageRouter.SubscribeMessages(ctx, input.SessionID, input.PartyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Run TSS Keygen protocol
|
// 4. Run TSS Keygen protocol
|
||||||
saveData, publicKey, err := uc.runKeygenProtocol(
|
saveData, publicKey, err := uc.runKeygenProtocol(
|
||||||
ctx,
|
ctx,
|
||||||
input.SessionID,
|
input.SessionID,
|
||||||
input.PartyID,
|
input.PartyID,
|
||||||
selfIndex,
|
selfIndex,
|
||||||
sessionInfo.Participants,
|
sessionInfo.Participants,
|
||||||
sessionInfo.ThresholdN,
|
sessionInfo.ThresholdN,
|
||||||
sessionInfo.ThresholdT,
|
sessionInfo.ThresholdT,
|
||||||
msgChan,
|
msgChan,
|
||||||
partyIndexMap,
|
partyIndexMap,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Encrypt and save the share
|
// 5. Encrypt and save the share
|
||||||
encryptedShare, err := uc.cryptoService.EncryptShare(saveData, input.PartyID)
|
encryptedShare, err := uc.cryptoService.EncryptShare(saveData, input.PartyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
keyShare := entities.NewPartyKeyShare(
|
keyShare := entities.NewPartyKeyShare(
|
||||||
input.PartyID,
|
input.PartyID,
|
||||||
selfIndex,
|
selfIndex,
|
||||||
input.SessionID,
|
input.SessionID,
|
||||||
sessionInfo.ThresholdN,
|
sessionInfo.ThresholdN,
|
||||||
sessionInfo.ThresholdT,
|
sessionInfo.ThresholdT,
|
||||||
encryptedShare,
|
encryptedShare,
|
||||||
publicKey,
|
publicKey,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := uc.keyShareRepo.Save(ctx, keyShare); err != nil {
|
if err := uc.keyShareRepo.Save(ctx, keyShare); err != nil {
|
||||||
return nil, ErrShareSaveFailed
|
return nil, ErrShareSaveFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. Report completion to coordinator
|
// 6. Report completion to coordinator
|
||||||
if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, publicKey); err != nil {
|
if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, publicKey); err != nil {
|
||||||
logger.Error("failed to report completion", zap.Error(err))
|
logger.Error("failed to report completion", zap.Error(err))
|
||||||
// Don't fail - share is saved
|
// Don't fail - share is saved
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ParticipateKeygenOutput{
|
return &ParticipateKeygenOutput{
|
||||||
Success: true,
|
Success: true,
|
||||||
KeyShare: keyShare,
|
KeyShare: keyShare,
|
||||||
PublicKey: publicKey,
|
PublicKey: publicKey,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// runKeygenProtocol runs the TSS keygen protocol using tss-lib
|
// runKeygenProtocol runs the TSS keygen protocol using tss-lib
|
||||||
func (uc *ParticipateKeygenUseCase) runKeygenProtocol(
|
func (uc *ParticipateKeygenUseCase) runKeygenProtocol(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID uuid.UUID,
|
sessionID uuid.UUID,
|
||||||
partyID string,
|
partyID string,
|
||||||
selfIndex int,
|
selfIndex int,
|
||||||
participants []ParticipantInfo,
|
participants []ParticipantInfo,
|
||||||
n, t int,
|
n, t int,
|
||||||
msgChan <-chan *MPCMessage,
|
msgChan <-chan *MPCMessage,
|
||||||
partyIndexMap map[string]int,
|
partyIndexMap map[string]int,
|
||||||
) ([]byte, []byte, error) {
|
) ([]byte, []byte, error) {
|
||||||
logger.Info("Running keygen protocol",
|
logger.Info("Running keygen protocol",
|
||||||
zap.String("session_id", sessionID.String()),
|
zap.String("session_id", sessionID.String()),
|
||||||
zap.String("party_id", partyID),
|
zap.String("party_id", partyID),
|
||||||
zap.Int("self_index", selfIndex),
|
zap.Int("self_index", selfIndex),
|
||||||
zap.Int("n", n),
|
zap.Int("n", n),
|
||||||
zap.Int("t", t))
|
zap.Int("t", t))
|
||||||
|
|
||||||
// Create message handler adapter
|
// Create message handler adapter
|
||||||
msgHandler := &keygenMessageHandler{
|
msgHandler := &keygenMessageHandler{
|
||||||
sessionID: sessionID,
|
sessionID: sessionID,
|
||||||
partyID: partyID,
|
partyID: partyID,
|
||||||
messageRouter: uc.messageRouter,
|
messageRouter: uc.messageRouter,
|
||||||
msgChan: make(chan *tss.ReceivedMessage, 100),
|
msgChan: make(chan *tss.ReceivedMessage, 100),
|
||||||
partyIndexMap: partyIndexMap,
|
partyIndexMap: partyIndexMap,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start message conversion goroutine
|
// Start message conversion goroutine
|
||||||
go msgHandler.convertMessages(ctx, msgChan)
|
go msgHandler.convertMessages(ctx, msgChan)
|
||||||
|
|
||||||
// Create keygen config
|
// Create keygen config
|
||||||
config := tss.KeygenConfig{
|
config := tss.KeygenConfig{
|
||||||
Threshold: t,
|
Threshold: t,
|
||||||
TotalParties: n,
|
TotalParties: n,
|
||||||
Timeout: 10 * time.Minute,
|
Timeout: 10 * time.Minute,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create party list
|
// Create party list
|
||||||
allParties := make([]tss.KeygenParty, len(participants))
|
allParties := make([]tss.KeygenParty, len(participants))
|
||||||
for i, p := range participants {
|
for i, p := range participants {
|
||||||
allParties[i] = tss.KeygenParty{
|
allParties[i] = tss.KeygenParty{
|
||||||
PartyID: p.PartyID,
|
PartyID: p.PartyID,
|
||||||
PartyIndex: p.PartyIndex,
|
PartyIndex: p.PartyIndex,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
selfParty := tss.KeygenParty{
|
selfParty := tss.KeygenParty{
|
||||||
PartyID: partyID,
|
PartyID: partyID,
|
||||||
PartyIndex: selfIndex,
|
PartyIndex: selfIndex,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create keygen session
|
// Create keygen session
|
||||||
session, err := tss.NewKeygenSession(config, selfParty, allParties, msgHandler)
|
session, err := tss.NewKeygenSession(config, selfParty, allParties, msgHandler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run keygen
|
// Run keygen
|
||||||
result, err := session.Start(ctx)
|
result, err := session.Start(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Keygen completed successfully",
|
logger.Info("Keygen completed successfully",
|
||||||
zap.String("session_id", sessionID.String()),
|
zap.String("session_id", sessionID.String()),
|
||||||
zap.String("party_id", partyID))
|
zap.String("party_id", partyID))
|
||||||
|
|
||||||
return result.LocalPartySaveData, result.PublicKeyBytes, nil
|
return result.LocalPartySaveData, result.PublicKeyBytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// keygenMessageHandler adapts MPCMessage channel to tss.MessageHandler
|
// keygenMessageHandler adapts MPCMessage channel to tss.MessageHandler
|
||||||
type keygenMessageHandler struct {
|
type keygenMessageHandler struct {
|
||||||
sessionID uuid.UUID
|
sessionID uuid.UUID
|
||||||
partyID string
|
partyID string
|
||||||
messageRouter MessageRouterClient
|
messageRouter MessageRouterClient
|
||||||
msgChan chan *tss.ReceivedMessage
|
msgChan chan *tss.ReceivedMessage
|
||||||
partyIndexMap map[string]int
|
partyIndexMap map[string]int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *keygenMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
|
func (h *keygenMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
|
||||||
return h.messageRouter.RouteMessage(ctx, h.sessionID, h.partyID, toParties, 0, msgBytes)
|
return h.messageRouter.RouteMessage(ctx, h.sessionID, h.partyID, toParties, 0, msgBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *keygenMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
|
func (h *keygenMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
|
||||||
return h.msgChan
|
return h.msgChan
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *keygenMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
|
func (h *keygenMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
close(h.msgChan)
|
close(h.msgChan)
|
||||||
return
|
return
|
||||||
case msg, ok := <-inChan:
|
case msg, ok := <-inChan:
|
||||||
if !ok {
|
if !ok {
|
||||||
close(h.msgChan)
|
close(h.msgChan)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fromIndex, exists := h.partyIndexMap[msg.FromParty]
|
fromIndex, exists := h.partyIndexMap[msg.FromParty]
|
||||||
if !exists {
|
if !exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
tssMsg := &tss.ReceivedMessage{
|
tssMsg := &tss.ReceivedMessage{
|
||||||
FromPartyIndex: fromIndex,
|
FromPartyIndex: fromIndex,
|
||||||
IsBroadcast: msg.IsBroadcast,
|
IsBroadcast: msg.IsBroadcast,
|
||||||
MsgBytes: msg.Payload,
|
MsgBytes: msg.Payload,
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case h.msgChan <- tssMsg:
|
case h.msgChan <- tssMsg:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,270 +1,270 @@
|
||||||
package use_cases
|
package use_cases
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"math/big"
|
"math/big"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/pkg/crypto"
|
"github.com/rwadurian/mpc-system/pkg/crypto"
|
||||||
"github.com/rwadurian/mpc-system/pkg/logger"
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
||||||
"github.com/rwadurian/mpc-system/pkg/tss"
|
"github.com/rwadurian/mpc-system/pkg/tss"
|
||||||
"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
|
"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrSigningFailed = errors.New("signing failed")
|
ErrSigningFailed = errors.New("signing failed")
|
||||||
ErrSigningTimeout = errors.New("signing timeout")
|
ErrSigningTimeout = errors.New("signing timeout")
|
||||||
ErrKeyShareNotFound = errors.New("key share not found")
|
ErrKeyShareNotFound = errors.New("key share not found")
|
||||||
ErrInvalidSignSession = errors.New("invalid sign session")
|
ErrInvalidSignSession = errors.New("invalid sign session")
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParticipateSigningInput contains input for signing participation
|
// ParticipateSigningInput contains input for signing participation
|
||||||
type ParticipateSigningInput struct {
|
type ParticipateSigningInput struct {
|
||||||
SessionID uuid.UUID
|
SessionID uuid.UUID
|
||||||
PartyID string
|
PartyID string
|
||||||
JoinToken string
|
JoinToken string
|
||||||
MessageHash []byte
|
MessageHash []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParticipateSigningOutput contains output from signing participation
|
// ParticipateSigningOutput contains output from signing participation
|
||||||
type ParticipateSigningOutput struct {
|
type ParticipateSigningOutput struct {
|
||||||
Success bool
|
Success bool
|
||||||
Signature []byte
|
Signature []byte
|
||||||
R *big.Int
|
R *big.Int
|
||||||
S *big.Int
|
S *big.Int
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParticipateSigningUseCase handles signing participation
|
// ParticipateSigningUseCase handles signing participation
|
||||||
type ParticipateSigningUseCase struct {
|
type ParticipateSigningUseCase struct {
|
||||||
keyShareRepo repositories.KeyShareRepository
|
keyShareRepo repositories.KeyShareRepository
|
||||||
sessionClient SessionCoordinatorClient
|
sessionClient SessionCoordinatorClient
|
||||||
messageRouter MessageRouterClient
|
messageRouter MessageRouterClient
|
||||||
cryptoService *crypto.CryptoService
|
cryptoService *crypto.CryptoService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewParticipateSigningUseCase creates a new participate signing use case
|
// NewParticipateSigningUseCase creates a new participate signing use case
|
||||||
func NewParticipateSigningUseCase(
|
func NewParticipateSigningUseCase(
|
||||||
keyShareRepo repositories.KeyShareRepository,
|
keyShareRepo repositories.KeyShareRepository,
|
||||||
sessionClient SessionCoordinatorClient,
|
sessionClient SessionCoordinatorClient,
|
||||||
messageRouter MessageRouterClient,
|
messageRouter MessageRouterClient,
|
||||||
cryptoService *crypto.CryptoService,
|
cryptoService *crypto.CryptoService,
|
||||||
) *ParticipateSigningUseCase {
|
) *ParticipateSigningUseCase {
|
||||||
return &ParticipateSigningUseCase{
|
return &ParticipateSigningUseCase{
|
||||||
keyShareRepo: keyShareRepo,
|
keyShareRepo: keyShareRepo,
|
||||||
sessionClient: sessionClient,
|
sessionClient: sessionClient,
|
||||||
messageRouter: messageRouter,
|
messageRouter: messageRouter,
|
||||||
cryptoService: cryptoService,
|
cryptoService: cryptoService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute participates in a signing session using real TSS protocol
|
// Execute participates in a signing session using real TSS protocol
|
||||||
func (uc *ParticipateSigningUseCase) Execute(
|
func (uc *ParticipateSigningUseCase) Execute(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
input ParticipateSigningInput,
|
input ParticipateSigningInput,
|
||||||
) (*ParticipateSigningOutput, error) {
|
) (*ParticipateSigningOutput, error) {
|
||||||
// 1. Join session via coordinator
|
// 1. Join session via coordinator
|
||||||
sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
|
sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if sessionInfo.SessionType != "sign" {
|
if sessionInfo.SessionType != "sign" {
|
||||||
return nil, ErrInvalidSignSession
|
return nil, ErrInvalidSignSession
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Load key share for this party
|
// 2. Load key share for this party
|
||||||
keyShares, err := uc.keyShareRepo.ListByParty(ctx, input.PartyID)
|
keyShares, err := uc.keyShareRepo.ListByParty(ctx, input.PartyID)
|
||||||
if err != nil || len(keyShares) == 0 {
|
if err != nil || len(keyShares) == 0 {
|
||||||
return nil, ErrKeyShareNotFound
|
return nil, ErrKeyShareNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use the most recent key share (in production, would match by public key or session reference)
|
// Use the most recent key share (in production, would match by public key or session reference)
|
||||||
keyShare := keyShares[len(keyShares)-1]
|
keyShare := keyShares[len(keyShares)-1]
|
||||||
|
|
||||||
// 3. Decrypt share data
|
// 3. Decrypt share data
|
||||||
shareData, err := uc.cryptoService.DecryptShare(keyShare.ShareData, input.PartyID)
|
shareData, err := uc.cryptoService.DecryptShare(keyShare.ShareData, input.PartyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Find self in participants and build party index map
|
// 4. Find self in participants and build party index map
|
||||||
var selfIndex int
|
var selfIndex int
|
||||||
partyIndexMap := make(map[string]int)
|
partyIndexMap := make(map[string]int)
|
||||||
for _, p := range sessionInfo.Participants {
|
for _, p := range sessionInfo.Participants {
|
||||||
partyIndexMap[p.PartyID] = p.PartyIndex
|
partyIndexMap[p.PartyID] = p.PartyIndex
|
||||||
if p.PartyID == input.PartyID {
|
if p.PartyID == input.PartyID {
|
||||||
selfIndex = p.PartyIndex
|
selfIndex = p.PartyIndex
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Subscribe to messages
|
// 5. Subscribe to messages
|
||||||
msgChan, err := uc.messageRouter.SubscribeMessages(ctx, input.SessionID, input.PartyID)
|
msgChan, err := uc.messageRouter.SubscribeMessages(ctx, input.SessionID, input.PartyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use message hash from session if not provided
|
// Use message hash from session if not provided
|
||||||
messageHash := input.MessageHash
|
messageHash := input.MessageHash
|
||||||
if len(messageHash) == 0 {
|
if len(messageHash) == 0 {
|
||||||
messageHash = sessionInfo.MessageHash
|
messageHash = sessionInfo.MessageHash
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. Run TSS Signing protocol
|
// 6. Run TSS Signing protocol
|
||||||
signature, r, s, err := uc.runSigningProtocol(
|
signature, r, s, err := uc.runSigningProtocol(
|
||||||
ctx,
|
ctx,
|
||||||
input.SessionID,
|
input.SessionID,
|
||||||
input.PartyID,
|
input.PartyID,
|
||||||
selfIndex,
|
selfIndex,
|
||||||
sessionInfo.Participants,
|
sessionInfo.Participants,
|
||||||
sessionInfo.ThresholdT,
|
sessionInfo.ThresholdT,
|
||||||
shareData,
|
shareData,
|
||||||
messageHash,
|
messageHash,
|
||||||
msgChan,
|
msgChan,
|
||||||
partyIndexMap,
|
partyIndexMap,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7. Update key share last used
|
// 7. Update key share last used
|
||||||
keyShare.MarkUsed()
|
keyShare.MarkUsed()
|
||||||
if err := uc.keyShareRepo.Update(ctx, keyShare); err != nil {
|
if err := uc.keyShareRepo.Update(ctx, keyShare); err != nil {
|
||||||
logger.Warn("failed to update key share last used", zap.Error(err))
|
logger.Warn("failed to update key share last used", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 8. Report completion to coordinator
|
// 8. Report completion to coordinator
|
||||||
if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, signature); err != nil {
|
if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, signature); err != nil {
|
||||||
logger.Error("failed to report signing completion", zap.Error(err))
|
logger.Error("failed to report signing completion", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ParticipateSigningOutput{
|
return &ParticipateSigningOutput{
|
||||||
Success: true,
|
Success: true,
|
||||||
Signature: signature,
|
Signature: signature,
|
||||||
R: r,
|
R: r,
|
||||||
S: s,
|
S: s,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// runSigningProtocol runs the TSS signing protocol using tss-lib
|
// runSigningProtocol runs the TSS signing protocol using tss-lib
|
||||||
func (uc *ParticipateSigningUseCase) runSigningProtocol(
|
func (uc *ParticipateSigningUseCase) runSigningProtocol(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID uuid.UUID,
|
sessionID uuid.UUID,
|
||||||
partyID string,
|
partyID string,
|
||||||
selfIndex int,
|
selfIndex int,
|
||||||
participants []ParticipantInfo,
|
participants []ParticipantInfo,
|
||||||
t int,
|
t int,
|
||||||
shareData []byte,
|
shareData []byte,
|
||||||
messageHash []byte,
|
messageHash []byte,
|
||||||
msgChan <-chan *MPCMessage,
|
msgChan <-chan *MPCMessage,
|
||||||
partyIndexMap map[string]int,
|
partyIndexMap map[string]int,
|
||||||
) ([]byte, *big.Int, *big.Int, error) {
|
) ([]byte, *big.Int, *big.Int, error) {
|
||||||
logger.Info("Running signing protocol",
|
logger.Info("Running signing protocol",
|
||||||
zap.String("session_id", sessionID.String()),
|
zap.String("session_id", sessionID.String()),
|
||||||
zap.String("party_id", partyID),
|
zap.String("party_id", partyID),
|
||||||
zap.Int("self_index", selfIndex),
|
zap.Int("self_index", selfIndex),
|
||||||
zap.Int("t", t),
|
zap.Int("t", t),
|
||||||
zap.Int("message_hash_len", len(messageHash)))
|
zap.Int("message_hash_len", len(messageHash)))
|
||||||
|
|
||||||
// Create message handler adapter
|
// Create message handler adapter
|
||||||
msgHandler := &signingMessageHandler{
|
msgHandler := &signingMessageHandler{
|
||||||
sessionID: sessionID,
|
sessionID: sessionID,
|
||||||
partyID: partyID,
|
partyID: partyID,
|
||||||
messageRouter: uc.messageRouter,
|
messageRouter: uc.messageRouter,
|
||||||
msgChan: make(chan *tss.ReceivedMessage, 100),
|
msgChan: make(chan *tss.ReceivedMessage, 100),
|
||||||
partyIndexMap: partyIndexMap,
|
partyIndexMap: partyIndexMap,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start message conversion goroutine
|
// Start message conversion goroutine
|
||||||
go msgHandler.convertMessages(ctx, msgChan)
|
go msgHandler.convertMessages(ctx, msgChan)
|
||||||
|
|
||||||
// Create signing config
|
// Create signing config
|
||||||
config := tss.SigningConfig{
|
config := tss.SigningConfig{
|
||||||
Threshold: t,
|
Threshold: t,
|
||||||
TotalSigners: len(participants),
|
TotalSigners: len(participants),
|
||||||
Timeout: 5 * time.Minute,
|
Timeout: 5 * time.Minute,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create party list
|
// Create party list
|
||||||
allParties := make([]tss.SigningParty, len(participants))
|
allParties := make([]tss.SigningParty, len(participants))
|
||||||
for i, p := range participants {
|
for i, p := range participants {
|
||||||
allParties[i] = tss.SigningParty{
|
allParties[i] = tss.SigningParty{
|
||||||
PartyID: p.PartyID,
|
PartyID: p.PartyID,
|
||||||
PartyIndex: p.PartyIndex,
|
PartyIndex: p.PartyIndex,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
selfParty := tss.SigningParty{
|
selfParty := tss.SigningParty{
|
||||||
PartyID: partyID,
|
PartyID: partyID,
|
||||||
PartyIndex: selfIndex,
|
PartyIndex: selfIndex,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create signing session
|
// Create signing session
|
||||||
session, err := tss.NewSigningSession(config, selfParty, allParties, messageHash, shareData, msgHandler)
|
session, err := tss.NewSigningSession(config, selfParty, allParties, messageHash, shareData, msgHandler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run signing
|
// Run signing
|
||||||
result, err := session.Start(ctx)
|
result, err := session.Start(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Signing completed successfully",
|
logger.Info("Signing completed successfully",
|
||||||
zap.String("session_id", sessionID.String()),
|
zap.String("session_id", sessionID.String()),
|
||||||
zap.String("party_id", partyID))
|
zap.String("party_id", partyID))
|
||||||
|
|
||||||
return result.Signature, result.R, result.S, nil
|
return result.Signature, result.R, result.S, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// signingMessageHandler adapts MPCMessage channel to tss.MessageHandler
|
// signingMessageHandler adapts MPCMessage channel to tss.MessageHandler
|
||||||
type signingMessageHandler struct {
|
type signingMessageHandler struct {
|
||||||
sessionID uuid.UUID
|
sessionID uuid.UUID
|
||||||
partyID string
|
partyID string
|
||||||
messageRouter MessageRouterClient
|
messageRouter MessageRouterClient
|
||||||
msgChan chan *tss.ReceivedMessage
|
msgChan chan *tss.ReceivedMessage
|
||||||
partyIndexMap map[string]int
|
partyIndexMap map[string]int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *signingMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
|
func (h *signingMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
|
||||||
return h.messageRouter.RouteMessage(ctx, h.sessionID, h.partyID, toParties, 0, msgBytes)
|
return h.messageRouter.RouteMessage(ctx, h.sessionID, h.partyID, toParties, 0, msgBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *signingMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
|
func (h *signingMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
|
||||||
return h.msgChan
|
return h.msgChan
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *signingMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
|
func (h *signingMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
close(h.msgChan)
|
close(h.msgChan)
|
||||||
return
|
return
|
||||||
case msg, ok := <-inChan:
|
case msg, ok := <-inChan:
|
||||||
if !ok {
|
if !ok {
|
||||||
close(h.msgChan)
|
close(h.msgChan)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fromIndex, exists := h.partyIndexMap[msg.FromParty]
|
fromIndex, exists := h.partyIndexMap[msg.FromParty]
|
||||||
if !exists {
|
if !exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
tssMsg := &tss.ReceivedMessage{
|
tssMsg := &tss.ReceivedMessage{
|
||||||
FromPartyIndex: fromIndex,
|
FromPartyIndex: fromIndex,
|
||||||
IsBroadcast: msg.IsBroadcast,
|
IsBroadcast: msg.IsBroadcast,
|
||||||
MsgBytes: msg.Payload,
|
MsgBytes: msg.Payload,
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case h.msgChan <- tssMsg:
|
case h.msgChan <- tssMsg:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,344 +1,382 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
|
|
||||||
"github.com/rwadurian/mpc-system/pkg/config"
|
"github.com/rwadurian/mpc-system/pkg/config"
|
||||||
"github.com/rwadurian/mpc-system/pkg/crypto"
|
"github.com/rwadurian/mpc-system/pkg/crypto"
|
||||||
"github.com/rwadurian/mpc-system/pkg/logger"
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
||||||
grpcclient "github.com/rwadurian/mpc-system/services/server-party/adapters/output/grpc"
|
grpcclient "github.com/rwadurian/mpc-system/services/server-party/adapters/output/grpc"
|
||||||
"github.com/rwadurian/mpc-system/services/server-party/adapters/output/postgres"
|
"github.com/rwadurian/mpc-system/services/server-party/adapters/output/postgres"
|
||||||
"github.com/rwadurian/mpc-system/services/server-party/application/use_cases"
|
"github.com/rwadurian/mpc-system/services/server-party/application/use_cases"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Parse flags
|
// Parse flags
|
||||||
configPath := flag.String("config", "", "Path to config file")
|
configPath := flag.String("config", "", "Path to config file")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
// Load configuration
|
// Load configuration
|
||||||
cfg, err := config.Load(*configPath)
|
cfg, err := config.Load(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Failed to load config: %v\n", err)
|
fmt.Printf("Failed to load config: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize logger
|
// Initialize logger
|
||||||
if err := logger.Init(&logger.Config{
|
if err := logger.Init(&logger.Config{
|
||||||
Level: cfg.Logger.Level,
|
Level: cfg.Logger.Level,
|
||||||
Encoding: cfg.Logger.Encoding,
|
Encoding: cfg.Logger.Encoding,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
fmt.Printf("Failed to initialize logger: %v\n", err)
|
fmt.Printf("Failed to initialize logger: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
defer logger.Sync()
|
defer logger.Sync()
|
||||||
|
|
||||||
logger.Info("Starting Server Party Service",
|
logger.Info("Starting Server Party Service",
|
||||||
zap.String("environment", cfg.Server.Environment),
|
zap.String("environment", cfg.Server.Environment),
|
||||||
zap.Int("http_port", cfg.Server.HTTPPort))
|
zap.Int("http_port", cfg.Server.HTTPPort))
|
||||||
|
|
||||||
// Initialize database connection
|
// Initialize database connection
|
||||||
db, err := initDatabase(cfg.Database)
|
db, err := initDatabase(cfg.Database)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to connect to database", zap.Error(err))
|
logger.Fatal("Failed to connect to database", zap.Error(err))
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
// Initialize crypto service with master key from environment
|
// Initialize crypto service with master key from environment
|
||||||
masterKeyHex := os.Getenv("MPC_CRYPTO_MASTER_KEY")
|
masterKeyHex := os.Getenv("MPC_CRYPTO_MASTER_KEY")
|
||||||
if masterKeyHex == "" {
|
if masterKeyHex == "" {
|
||||||
masterKeyHex = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" // 64 hex chars = 32 bytes
|
masterKeyHex = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" // 64 hex chars = 32 bytes
|
||||||
}
|
}
|
||||||
masterKey, err := hex.DecodeString(masterKeyHex)
|
masterKey, err := hex.DecodeString(masterKeyHex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Invalid master key format", zap.Error(err))
|
logger.Fatal("Invalid master key format", zap.Error(err))
|
||||||
}
|
}
|
||||||
cryptoService, err := crypto.NewCryptoService(masterKey)
|
cryptoService, err := crypto.NewCryptoService(masterKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to create crypto service", zap.Error(err))
|
logger.Fatal("Failed to create crypto service", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get gRPC service addresses from environment
|
// Get gRPC service addresses from environment
|
||||||
coordinatorAddr := os.Getenv("SESSION_COORDINATOR_ADDR")
|
coordinatorAddr := os.Getenv("SESSION_COORDINATOR_ADDR")
|
||||||
if coordinatorAddr == "" {
|
if coordinatorAddr == "" {
|
||||||
coordinatorAddr = "localhost:9091"
|
coordinatorAddr = "localhost:9091"
|
||||||
}
|
}
|
||||||
routerAddr := os.Getenv("MESSAGE_ROUTER_ADDR")
|
routerAddr := os.Getenv("MESSAGE_ROUTER_ADDR")
|
||||||
if routerAddr == "" {
|
if routerAddr == "" {
|
||||||
routerAddr = "localhost:9092"
|
routerAddr = "localhost:9092"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize gRPC clients
|
// Initialize gRPC clients
|
||||||
sessionClient, err := grpcclient.NewSessionCoordinatorClient(coordinatorAddr)
|
sessionClient, err := grpcclient.NewSessionCoordinatorClient(coordinatorAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to connect to session coordinator", zap.Error(err))
|
logger.Fatal("Failed to connect to session coordinator", zap.Error(err))
|
||||||
}
|
}
|
||||||
defer sessionClient.Close()
|
defer sessionClient.Close()
|
||||||
|
|
||||||
messageRouter, err := grpcclient.NewMessageRouterClient(routerAddr)
|
messageRouter, err := grpcclient.NewMessageRouterClient(routerAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to connect to message router", zap.Error(err))
|
logger.Fatal("Failed to connect to message router", zap.Error(err))
|
||||||
}
|
}
|
||||||
defer messageRouter.Close()
|
defer messageRouter.Close()
|
||||||
|
|
||||||
// Initialize repositories
|
// Initialize repositories
|
||||||
keyShareRepo := postgres.NewKeySharePostgresRepo(db)
|
keyShareRepo := postgres.NewKeySharePostgresRepo(db)
|
||||||
|
|
||||||
// Initialize use cases with real gRPC clients
|
// Initialize use cases with real gRPC clients
|
||||||
participateKeygenUC := use_cases.NewParticipateKeygenUseCase(
|
participateKeygenUC := use_cases.NewParticipateKeygenUseCase(
|
||||||
keyShareRepo,
|
keyShareRepo,
|
||||||
sessionClient,
|
sessionClient,
|
||||||
messageRouter,
|
messageRouter,
|
||||||
cryptoService,
|
cryptoService,
|
||||||
)
|
)
|
||||||
participateSigningUC := use_cases.NewParticipateSigningUseCase(
|
participateSigningUC := use_cases.NewParticipateSigningUseCase(
|
||||||
keyShareRepo,
|
keyShareRepo,
|
||||||
sessionClient,
|
sessionClient,
|
||||||
messageRouter,
|
messageRouter,
|
||||||
cryptoService,
|
cryptoService,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Create shutdown context
|
// Create shutdown context
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start HTTP server
|
// Start HTTP server
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
if err := startHTTPServer(cfg, participateKeygenUC, participateSigningUC, keyShareRepo); err != nil {
|
if err := startHTTPServer(cfg, participateKeygenUC, participateSigningUC, keyShareRepo); err != nil {
|
||||||
errChan <- fmt.Errorf("HTTP server error: %w", err)
|
errChan <- fmt.Errorf("HTTP server error: %w", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Wait for shutdown signal
|
// Wait for shutdown signal
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case sig := <-sigChan:
|
case sig := <-sigChan:
|
||||||
logger.Info("Received shutdown signal", zap.String("signal", sig.String()))
|
logger.Info("Received shutdown signal", zap.String("signal", sig.String()))
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
logger.Error("Server error", zap.Error(err))
|
logger.Error("Server error", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Graceful shutdown
|
// Graceful shutdown
|
||||||
logger.Info("Shutting down...")
|
logger.Info("Shutting down...")
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
time.Sleep(5 * time.Second)
|
time.Sleep(5 * time.Second)
|
||||||
logger.Info("Shutdown complete")
|
logger.Info("Shutdown complete")
|
||||||
|
|
||||||
_ = ctx
|
_ = ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) {
|
func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) {
|
||||||
db, err := sql.Open("postgres", cfg.DSN())
|
const maxRetries = 10
|
||||||
if err != nil {
|
const retryDelay = 2 * time.Second
|
||||||
return nil, err
|
|
||||||
}
|
var db *sql.DB
|
||||||
|
var err error
|
||||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
|
||||||
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
for i := 0; i < maxRetries; i++ {
|
||||||
db.SetConnMaxLifetime(cfg.ConnMaxLife)
|
db, err = sql.Open("postgres", cfg.DSN())
|
||||||
|
if err != nil {
|
||||||
if err := db.Ping(); err != nil {
|
logger.Warn("Failed to open database connection, retrying...",
|
||||||
return nil, err
|
zap.Int("attempt", i+1),
|
||||||
}
|
zap.Int("max_retries", maxRetries),
|
||||||
|
zap.Error(err))
|
||||||
logger.Info("Connected to PostgreSQL")
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
return db, nil
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
func startHTTPServer(
|
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||||
cfg *config.Config,
|
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||||
participateKeygenUC *use_cases.ParticipateKeygenUseCase,
|
db.SetConnMaxLifetime(cfg.ConnMaxLife)
|
||||||
participateSigningUC *use_cases.ParticipateSigningUseCase,
|
|
||||||
keyShareRepo *postgres.KeySharePostgresRepo,
|
// Test connection with Ping
|
||||||
) error {
|
if err = db.Ping(); err != nil {
|
||||||
if cfg.Server.Environment == "production" {
|
logger.Warn("Failed to ping database, retrying...",
|
||||||
gin.SetMode(gin.ReleaseMode)
|
zap.Int("attempt", i+1),
|
||||||
}
|
zap.Int("max_retries", maxRetries),
|
||||||
|
zap.Error(err))
|
||||||
router := gin.New()
|
db.Close()
|
||||||
router.Use(gin.Recovery())
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
router.Use(gin.Logger())
|
continue
|
||||||
|
}
|
||||||
// Health check
|
|
||||||
router.GET("/health", func(c *gin.Context) {
|
// Verify database is actually usable with a simple query
|
||||||
c.JSON(http.StatusOK, gin.H{
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
"status": "healthy",
|
var result int
|
||||||
"service": "server-party",
|
err = db.QueryRowContext(ctx, "SELECT 1").Scan(&result)
|
||||||
})
|
cancel()
|
||||||
})
|
if err != nil {
|
||||||
|
logger.Warn("Database ping succeeded but query failed, retrying...",
|
||||||
// API routes
|
zap.Int("attempt", i+1),
|
||||||
api := router.Group("/api/v1")
|
zap.Int("max_retries", maxRetries),
|
||||||
{
|
zap.Error(err))
|
||||||
// Keygen participation endpoint
|
db.Close()
|
||||||
api.POST("/keygen/participate", func(c *gin.Context) {
|
time.Sleep(retryDelay * time.Duration(i+1))
|
||||||
var req struct {
|
continue
|
||||||
SessionID string `json:"session_id" binding:"required"`
|
}
|
||||||
PartyID string `json:"party_id" binding:"required"`
|
|
||||||
JoinToken string `json:"join_token" binding:"required"`
|
logger.Info("Connected to PostgreSQL and verified connectivity",
|
||||||
}
|
zap.Int("attempt", i+1))
|
||||||
|
return db, nil
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
}
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
return nil, fmt.Errorf("failed to connect to database after %d retries: %w", maxRetries, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID, err := uuid.Parse(req.SessionID)
|
func startHTTPServer(
|
||||||
if err != nil {
|
cfg *config.Config,
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session_id format"})
|
participateKeygenUC *use_cases.ParticipateKeygenUseCase,
|
||||||
return
|
participateSigningUC *use_cases.ParticipateSigningUseCase,
|
||||||
}
|
keyShareRepo *postgres.KeySharePostgresRepo,
|
||||||
|
) error {
|
||||||
// Execute keygen participation asynchronously
|
if cfg.Server.Environment == "production" {
|
||||||
go func() {
|
gin.SetMode(gin.ReleaseMode)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
}
|
||||||
defer cancel()
|
|
||||||
|
router := gin.New()
|
||||||
input := use_cases.ParticipateKeygenInput{
|
router.Use(gin.Recovery())
|
||||||
SessionID: sessionID,
|
router.Use(gin.Logger())
|
||||||
PartyID: req.PartyID,
|
|
||||||
JoinToken: req.JoinToken,
|
// Health check
|
||||||
}
|
router.GET("/health", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
output, err := participateKeygenUC.Execute(ctx, input)
|
"status": "healthy",
|
||||||
if err != nil {
|
"service": "server-party",
|
||||||
logger.Error("Keygen participation failed",
|
})
|
||||||
zap.String("session_id", req.SessionID),
|
})
|
||||||
zap.String("party_id", req.PartyID),
|
|
||||||
zap.Error(err))
|
// API routes
|
||||||
return
|
api := router.Group("/api/v1")
|
||||||
}
|
{
|
||||||
|
// Keygen participation endpoint
|
||||||
logger.Info("Keygen participation completed",
|
api.POST("/keygen/participate", func(c *gin.Context) {
|
||||||
zap.String("session_id", req.SessionID),
|
var req struct {
|
||||||
zap.String("party_id", req.PartyID),
|
SessionID string `json:"session_id" binding:"required"`
|
||||||
zap.Bool("success", output.Success))
|
PartyID string `json:"party_id" binding:"required"`
|
||||||
}()
|
JoinToken string `json:"join_token" binding:"required"`
|
||||||
|
}
|
||||||
c.JSON(http.StatusAccepted, gin.H{
|
|
||||||
"message": "keygen participation initiated",
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
"session_id": req.SessionID,
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
"party_id": req.PartyID,
|
return
|
||||||
})
|
}
|
||||||
})
|
|
||||||
|
sessionID, err := uuid.Parse(req.SessionID)
|
||||||
// Signing participation endpoint
|
if err != nil {
|
||||||
api.POST("/sign/participate", func(c *gin.Context) {
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session_id format"})
|
||||||
var req struct {
|
return
|
||||||
SessionID string `json:"session_id" binding:"required"`
|
}
|
||||||
PartyID string `json:"party_id" binding:"required"`
|
|
||||||
JoinToken string `json:"join_token" binding:"required"`
|
// Execute keygen participation asynchronously
|
||||||
MessageHash string `json:"message_hash"`
|
go func() {
|
||||||
}
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||||
|
defer cancel()
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
input := use_cases.ParticipateKeygenInput{
|
||||||
return
|
SessionID: sessionID,
|
||||||
}
|
PartyID: req.PartyID,
|
||||||
|
JoinToken: req.JoinToken,
|
||||||
sessionID, err := uuid.Parse(req.SessionID)
|
}
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session_id format"})
|
output, err := participateKeygenUC.Execute(ctx, input)
|
||||||
return
|
if err != nil {
|
||||||
}
|
logger.Error("Keygen participation failed",
|
||||||
|
zap.String("session_id", req.SessionID),
|
||||||
// Parse message hash if provided
|
zap.String("party_id", req.PartyID),
|
||||||
var messageHash []byte
|
zap.Error(err))
|
||||||
if req.MessageHash != "" {
|
return
|
||||||
messageHash, err = hex.DecodeString(req.MessageHash)
|
}
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid message_hash format (expected hex)"})
|
logger.Info("Keygen participation completed",
|
||||||
return
|
zap.String("session_id", req.SessionID),
|
||||||
}
|
zap.String("party_id", req.PartyID),
|
||||||
}
|
zap.Bool("success", output.Success))
|
||||||
|
}()
|
||||||
// Execute signing participation asynchronously
|
|
||||||
go func() {
|
c.JSON(http.StatusAccepted, gin.H{
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
"message": "keygen participation initiated",
|
||||||
defer cancel()
|
"session_id": req.SessionID,
|
||||||
|
"party_id": req.PartyID,
|
||||||
input := use_cases.ParticipateSigningInput{
|
})
|
||||||
SessionID: sessionID,
|
})
|
||||||
PartyID: req.PartyID,
|
|
||||||
JoinToken: req.JoinToken,
|
// Signing participation endpoint
|
||||||
MessageHash: messageHash,
|
api.POST("/sign/participate", func(c *gin.Context) {
|
||||||
}
|
var req struct {
|
||||||
|
SessionID string `json:"session_id" binding:"required"`
|
||||||
output, err := participateSigningUC.Execute(ctx, input)
|
PartyID string `json:"party_id" binding:"required"`
|
||||||
if err != nil {
|
JoinToken string `json:"join_token" binding:"required"`
|
||||||
logger.Error("Signing participation failed",
|
MessageHash string `json:"message_hash"`
|
||||||
zap.String("session_id", req.SessionID),
|
}
|
||||||
zap.String("party_id", req.PartyID),
|
|
||||||
zap.Error(err))
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
return
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
}
|
return
|
||||||
|
}
|
||||||
logger.Info("Signing participation completed",
|
|
||||||
zap.String("session_id", req.SessionID),
|
sessionID, err := uuid.Parse(req.SessionID)
|
||||||
zap.String("party_id", req.PartyID),
|
if err != nil {
|
||||||
zap.Bool("success", output.Success),
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session_id format"})
|
||||||
zap.Int("signature_len", len(output.Signature)))
|
return
|
||||||
}()
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusAccepted, gin.H{
|
// Parse message hash if provided
|
||||||
"message": "signing participation initiated",
|
var messageHash []byte
|
||||||
"session_id": req.SessionID,
|
if req.MessageHash != "" {
|
||||||
"party_id": req.PartyID,
|
messageHash, err = hex.DecodeString(req.MessageHash)
|
||||||
})
|
if err != nil {
|
||||||
})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid message_hash format (expected hex)"})
|
||||||
|
return
|
||||||
// Get key shares for a party
|
}
|
||||||
api.GET("/shares/:party_id", func(c *gin.Context) {
|
}
|
||||||
partyID := c.Param("party_id")
|
|
||||||
|
// Execute signing participation asynchronously
|
||||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
|
go func() {
|
||||||
defer cancel()
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||||
|
defer cancel()
|
||||||
shares, err := keyShareRepo.ListByParty(ctx, partyID)
|
|
||||||
if err != nil {
|
input := use_cases.ParticipateSigningInput{
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to fetch shares"})
|
SessionID: sessionID,
|
||||||
return
|
PartyID: req.PartyID,
|
||||||
}
|
JoinToken: req.JoinToken,
|
||||||
|
MessageHash: messageHash,
|
||||||
// Return share metadata (not the actual encrypted data)
|
}
|
||||||
shareInfos := make([]gin.H, len(shares))
|
|
||||||
for i, share := range shares {
|
output, err := participateSigningUC.Execute(ctx, input)
|
||||||
shareInfos[i] = gin.H{
|
if err != nil {
|
||||||
"id": share.ID.String(),
|
logger.Error("Signing participation failed",
|
||||||
"party_id": share.PartyID,
|
zap.String("session_id", req.SessionID),
|
||||||
"party_index": share.PartyIndex,
|
zap.String("party_id", req.PartyID),
|
||||||
"public_key": hex.EncodeToString(share.PublicKey),
|
zap.Error(err))
|
||||||
"created_at": share.CreatedAt,
|
return
|
||||||
"last_used": share.LastUsedAt,
|
}
|
||||||
}
|
|
||||||
}
|
logger.Info("Signing participation completed",
|
||||||
|
zap.String("session_id", req.SessionID),
|
||||||
c.JSON(http.StatusOK, gin.H{
|
zap.String("party_id", req.PartyID),
|
||||||
"party_id": partyID,
|
zap.Bool("success", output.Success),
|
||||||
"count": len(shares),
|
zap.Int("signature_len", len(output.Signature)))
|
||||||
"shares": shareInfos,
|
}()
|
||||||
})
|
|
||||||
})
|
c.JSON(http.StatusAccepted, gin.H{
|
||||||
}
|
"message": "signing participation initiated",
|
||||||
|
"session_id": req.SessionID,
|
||||||
logger.Info("Starting HTTP server", zap.Int("port", cfg.Server.HTTPPort))
|
"party_id": req.PartyID,
|
||||||
return router.Run(fmt.Sprintf(":%d", cfg.Server.HTTPPort))
|
})
|
||||||
}
|
})
|
||||||
|
|
||||||
|
// Get key shares for a party
|
||||||
|
api.GET("/shares/:party_id", func(c *gin.Context) {
|
||||||
|
partyID := c.Param("party_id")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
shares, err := keyShareRepo.ListByParty(ctx, partyID)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to fetch shares"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return share metadata (not the actual encrypted data)
|
||||||
|
shareInfos := make([]gin.H, len(shares))
|
||||||
|
for i, share := range shares {
|
||||||
|
shareInfos[i] = gin.H{
|
||||||
|
"id": share.ID.String(),
|
||||||
|
"party_id": share.PartyID,
|
||||||
|
"party_index": share.PartyIndex,
|
||||||
|
"public_key": hex.EncodeToString(share.PublicKey),
|
||||||
|
"created_at": share.CreatedAt,
|
||||||
|
"last_used": share.LastUsedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"party_id": partyID,
|
||||||
|
"count": len(shares),
|
||||||
|
"shares": shareInfos,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Starting HTTP server", zap.Int("port", cfg.Server.HTTPPort))
|
||||||
|
return router.Run(fmt.Sprintf(":%d", cfg.Server.HTTPPort))
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,56 +1,56 @@
|
||||||
package entities
|
package entities
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PartyKeyShare represents the server's key share
|
// PartyKeyShare represents the server's key share
|
||||||
type PartyKeyShare struct {
|
type PartyKeyShare struct {
|
||||||
ID uuid.UUID
|
ID uuid.UUID
|
||||||
PartyID string
|
PartyID string
|
||||||
PartyIndex int
|
PartyIndex int
|
||||||
SessionID uuid.UUID // Keygen session ID
|
SessionID uuid.UUID // Keygen session ID
|
||||||
ThresholdN int
|
ThresholdN int
|
||||||
ThresholdT int
|
ThresholdT int
|
||||||
ShareData []byte // Encrypted tss-lib LocalPartySaveData
|
ShareData []byte // Encrypted tss-lib LocalPartySaveData
|
||||||
PublicKey []byte // Group public key
|
PublicKey []byte // Group public key
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
LastUsedAt *time.Time
|
LastUsedAt *time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPartyKeyShare creates a new party key share
|
// NewPartyKeyShare creates a new party key share
|
||||||
func NewPartyKeyShare(
|
func NewPartyKeyShare(
|
||||||
partyID string,
|
partyID string,
|
||||||
partyIndex int,
|
partyIndex int,
|
||||||
sessionID uuid.UUID,
|
sessionID uuid.UUID,
|
||||||
thresholdN, thresholdT int,
|
thresholdN, thresholdT int,
|
||||||
shareData, publicKey []byte,
|
shareData, publicKey []byte,
|
||||||
) *PartyKeyShare {
|
) *PartyKeyShare {
|
||||||
return &PartyKeyShare{
|
return &PartyKeyShare{
|
||||||
ID: uuid.New(),
|
ID: uuid.New(),
|
||||||
PartyID: partyID,
|
PartyID: partyID,
|
||||||
PartyIndex: partyIndex,
|
PartyIndex: partyIndex,
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
ThresholdN: thresholdN,
|
ThresholdN: thresholdN,
|
||||||
ThresholdT: thresholdT,
|
ThresholdT: thresholdT,
|
||||||
ShareData: shareData,
|
ShareData: shareData,
|
||||||
PublicKey: publicKey,
|
PublicKey: publicKey,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkUsed updates the last used timestamp
|
// MarkUsed updates the last used timestamp
|
||||||
func (k *PartyKeyShare) MarkUsed() {
|
func (k *PartyKeyShare) MarkUsed() {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
k.LastUsedAt = &now
|
k.LastUsedAt = &now
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValid checks if the key share is valid
|
// IsValid checks if the key share is valid
|
||||||
func (k *PartyKeyShare) IsValid() bool {
|
func (k *PartyKeyShare) IsValid() bool {
|
||||||
return k.ID != uuid.Nil &&
|
return k.ID != uuid.Nil &&
|
||||||
k.PartyID != "" &&
|
k.PartyID != "" &&
|
||||||
len(k.ShareData) > 0 &&
|
len(k.ShareData) > 0 &&
|
||||||
len(k.PublicKey) > 0
|
len(k.PublicKey) > 0
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,32 +1,32 @@
|
||||||
package repositories
|
package repositories
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
|
"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
|
||||||
)
|
)
|
||||||
|
|
||||||
// KeyShareRepository defines the interface for key share persistence
|
// KeyShareRepository defines the interface for key share persistence
|
||||||
type KeyShareRepository interface {
|
type KeyShareRepository interface {
|
||||||
// Save persists a new key share
|
// Save persists a new key share
|
||||||
Save(ctx context.Context, keyShare *entities.PartyKeyShare) error
|
Save(ctx context.Context, keyShare *entities.PartyKeyShare) error
|
||||||
|
|
||||||
// FindByID retrieves a key share by ID
|
// FindByID retrieves a key share by ID
|
||||||
FindByID(ctx context.Context, id uuid.UUID) (*entities.PartyKeyShare, error)
|
FindByID(ctx context.Context, id uuid.UUID) (*entities.PartyKeyShare, error)
|
||||||
|
|
||||||
// FindBySessionAndParty retrieves a key share by session and party
|
// FindBySessionAndParty retrieves a key share by session and party
|
||||||
FindBySessionAndParty(ctx context.Context, sessionID uuid.UUID, partyID string) (*entities.PartyKeyShare, error)
|
FindBySessionAndParty(ctx context.Context, sessionID uuid.UUID, partyID string) (*entities.PartyKeyShare, error)
|
||||||
|
|
||||||
// FindByPublicKey retrieves key shares by public key
|
// FindByPublicKey retrieves key shares by public key
|
||||||
FindByPublicKey(ctx context.Context, publicKey []byte) ([]*entities.PartyKeyShare, error)
|
FindByPublicKey(ctx context.Context, publicKey []byte) ([]*entities.PartyKeyShare, error)
|
||||||
|
|
||||||
// Update updates an existing key share
|
// Update updates an existing key share
|
||||||
Update(ctx context.Context, keyShare *entities.PartyKeyShare) error
|
Update(ctx context.Context, keyShare *entities.PartyKeyShare) error
|
||||||
|
|
||||||
// Delete removes a key share
|
// Delete removes a key share
|
||||||
Delete(ctx context.Context, id uuid.UUID) error
|
Delete(ctx context.Context, id uuid.UUID) error
|
||||||
|
|
||||||
// ListByParty lists all key shares for a party
|
// ListByParty lists all key shares for a party
|
||||||
ListByParty(ctx context.Context, partyID string) ([]*entities.PartyKeyShare, error)
|
ListByParty(ctx context.Context, partyID string) ([]*entities.PartyKeyShare, error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,52 +1,52 @@
|
||||||
# Build stage
|
# Build stage
|
||||||
FROM golang:1.21-alpine AS builder
|
FROM golang:1.21-alpine AS builder
|
||||||
|
|
||||||
# Install dependencies
|
# Install dependencies
|
||||||
RUN apk add --no-cache git ca-certificates
|
RUN apk add --no-cache git ca-certificates
|
||||||
|
|
||||||
# Set Go proxy (can be overridden with --build-arg GOPROXY=...)
|
# Set Go proxy (can be overridden with --build-arg GOPROXY=...)
|
||||||
ARG GOPROXY=https://proxy.golang.org,direct
|
ARG GOPROXY=https://proxy.golang.org,direct
|
||||||
ENV GOPROXY=${GOPROXY}
|
ENV GOPROXY=${GOPROXY}
|
||||||
|
|
||||||
# Set working directory
|
# Set working directory
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Copy go mod files
|
# Copy go mod files
|
||||||
COPY go.mod go.sum ./
|
COPY go.mod go.sum ./
|
||||||
|
|
||||||
# Download dependencies
|
# Download dependencies
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
|
|
||||||
# Copy source code
|
# Copy source code
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Build the application
|
# Build the application
|
||||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
||||||
-ldflags="-w -s" \
|
-ldflags="-w -s" \
|
||||||
-o /bin/session-coordinator \
|
-o /bin/session-coordinator \
|
||||||
./services/session-coordinator/cmd/server
|
./services/session-coordinator/cmd/server
|
||||||
|
|
||||||
# Final stage
|
# Final stage
|
||||||
FROM alpine:3.18
|
FROM alpine:3.18
|
||||||
|
|
||||||
# Install ca-certificates and curl for HTTPS and health check
|
# Install ca-certificates and curl for HTTPS and health check
|
||||||
RUN apk --no-cache add ca-certificates curl
|
RUN apk --no-cache add ca-certificates curl
|
||||||
|
|
||||||
# Create non-root user
|
# Create non-root user
|
||||||
RUN adduser -D -s /bin/sh mpc
|
RUN adduser -D -s /bin/sh mpc
|
||||||
|
|
||||||
# Copy binary from builder
|
# Copy binary from builder
|
||||||
COPY --from=builder /bin/session-coordinator /bin/session-coordinator
|
COPY --from=builder /bin/session-coordinator /bin/session-coordinator
|
||||||
|
|
||||||
# Switch to non-root user
|
# Switch to non-root user
|
||||||
USER mpc
|
USER mpc
|
||||||
|
|
||||||
# Expose ports
|
# Expose ports
|
||||||
EXPOSE 50051 8080
|
EXPOSE 50051 8080
|
||||||
|
|
||||||
# Health check
|
# Health check
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||||
CMD curl -sf http://localhost:8080/health || exit 1
|
CMD curl -sf http://localhost:8080/health || exit 1
|
||||||
|
|
||||||
# Run the application
|
# Run the application
|
||||||
ENTRYPOINT ["/bin/session-coordinator"]
|
ENTRYPOINT ["/bin/session-coordinator"]
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,276 +1,276 @@
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities"
|
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities"
|
||||||
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories"
|
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories"
|
||||||
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects"
|
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MessagePostgresRepo implements MessageRepository for PostgreSQL
|
// MessagePostgresRepo implements MessageRepository for PostgreSQL
|
||||||
type MessagePostgresRepo struct {
|
type MessagePostgresRepo struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMessagePostgresRepo creates a new PostgreSQL message repository
|
// NewMessagePostgresRepo creates a new PostgreSQL message repository
|
||||||
func NewMessagePostgresRepo(db *sql.DB) *MessagePostgresRepo {
|
func NewMessagePostgresRepo(db *sql.DB) *MessagePostgresRepo {
|
||||||
return &MessagePostgresRepo{db: db}
|
return &MessagePostgresRepo{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveMessage persists a new message
|
// SaveMessage persists a new message
|
||||||
func (r *MessagePostgresRepo) SaveMessage(ctx context.Context, msg *entities.SessionMessage) error {
|
func (r *MessagePostgresRepo) SaveMessage(ctx context.Context, msg *entities.SessionMessage) error {
|
||||||
toParties := msg.GetToPartyStrings()
|
toParties := msg.GetToPartyStrings()
|
||||||
|
|
||||||
_, err := r.db.ExecContext(ctx, `
|
_, err := r.db.ExecContext(ctx, `
|
||||||
INSERT INTO mpc_messages (
|
INSERT INTO mpc_messages (
|
||||||
id, session_id, from_party, to_parties, round_number, message_type, payload, created_at
|
id, session_id, from_party, to_parties, round_number, message_type, payload, created_at
|
||||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
`,
|
`,
|
||||||
msg.ID,
|
msg.ID,
|
||||||
msg.SessionID.UUID(),
|
msg.SessionID.UUID(),
|
||||||
msg.FromParty.String(),
|
msg.FromParty.String(),
|
||||||
pq.Array(toParties),
|
pq.Array(toParties),
|
||||||
msg.RoundNumber,
|
msg.RoundNumber,
|
||||||
msg.MessageType,
|
msg.MessageType,
|
||||||
msg.Payload,
|
msg.Payload,
|
||||||
msg.CreatedAt,
|
msg.CreatedAt,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID retrieves a message by ID
|
// GetByID retrieves a message by ID
|
||||||
func (r *MessagePostgresRepo) GetByID(ctx context.Context, id uuid.UUID) (*entities.SessionMessage, error) {
|
func (r *MessagePostgresRepo) GetByID(ctx context.Context, id uuid.UUID) (*entities.SessionMessage, error) {
|
||||||
var row messageRow
|
var row messageRow
|
||||||
var toParties []string
|
var toParties []string
|
||||||
|
|
||||||
err := r.db.QueryRowContext(ctx, `
|
err := r.db.QueryRowContext(ctx, `
|
||||||
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
||||||
FROM mpc_messages WHERE id = $1
|
FROM mpc_messages WHERE id = $1
|
||||||
`, id).Scan(
|
`, id).Scan(
|
||||||
&row.ID,
|
&row.ID,
|
||||||
&row.SessionID,
|
&row.SessionID,
|
||||||
&row.FromParty,
|
&row.FromParty,
|
||||||
pq.Array(&toParties),
|
pq.Array(&toParties),
|
||||||
&row.RoundNumber,
|
&row.RoundNumber,
|
||||||
&row.MessageType,
|
&row.MessageType,
|
||||||
&row.Payload,
|
&row.Payload,
|
||||||
&row.CreatedAt,
|
&row.CreatedAt,
|
||||||
&row.DeliveredAt,
|
&row.DeliveredAt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.rowToMessage(row, toParties)
|
return r.rowToMessage(row, toParties)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMessages retrieves messages for a session and party after a specific time
|
// GetMessages retrieves messages for a session and party after a specific time
|
||||||
func (r *MessagePostgresRepo) GetMessages(
|
func (r *MessagePostgresRepo) GetMessages(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID value_objects.SessionID,
|
sessionID value_objects.SessionID,
|
||||||
partyID value_objects.PartyID,
|
partyID value_objects.PartyID,
|
||||||
afterTime time.Time,
|
afterTime time.Time,
|
||||||
) ([]*entities.SessionMessage, error) {
|
) ([]*entities.SessionMessage, error) {
|
||||||
rows, err := r.db.QueryContext(ctx, `
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
||||||
FROM mpc_messages
|
FROM mpc_messages
|
||||||
WHERE session_id = $1
|
WHERE session_id = $1
|
||||||
AND created_at > $2
|
AND created_at > $2
|
||||||
AND (to_parties IS NULL OR $3 = ANY(to_parties))
|
AND (to_parties IS NULL OR $3 = ANY(to_parties))
|
||||||
AND from_party != $3
|
AND from_party != $3
|
||||||
ORDER BY created_at ASC
|
ORDER BY created_at ASC
|
||||||
`, sessionID.UUID(), afterTime, partyID.String())
|
`, sessionID.UUID(), afterTime, partyID.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
return r.scanMessages(rows)
|
return r.scanMessages(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUndeliveredMessages retrieves undelivered messages for a party
|
// GetUndeliveredMessages retrieves undelivered messages for a party
|
||||||
func (r *MessagePostgresRepo) GetUndeliveredMessages(
|
func (r *MessagePostgresRepo) GetUndeliveredMessages(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID value_objects.SessionID,
|
sessionID value_objects.SessionID,
|
||||||
partyID value_objects.PartyID,
|
partyID value_objects.PartyID,
|
||||||
) ([]*entities.SessionMessage, error) {
|
) ([]*entities.SessionMessage, error) {
|
||||||
rows, err := r.db.QueryContext(ctx, `
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
||||||
FROM mpc_messages
|
FROM mpc_messages
|
||||||
WHERE session_id = $1
|
WHERE session_id = $1
|
||||||
AND delivered_at IS NULL
|
AND delivered_at IS NULL
|
||||||
AND (to_parties IS NULL OR $2 = ANY(to_parties))
|
AND (to_parties IS NULL OR $2 = ANY(to_parties))
|
||||||
AND from_party != $2
|
AND from_party != $2
|
||||||
ORDER BY created_at ASC
|
ORDER BY created_at ASC
|
||||||
`, sessionID.UUID(), partyID.String())
|
`, sessionID.UUID(), partyID.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
return r.scanMessages(rows)
|
return r.scanMessages(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMessagesByRound retrieves messages for a specific round
|
// GetMessagesByRound retrieves messages for a specific round
|
||||||
func (r *MessagePostgresRepo) GetMessagesByRound(
|
func (r *MessagePostgresRepo) GetMessagesByRound(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID value_objects.SessionID,
|
sessionID value_objects.SessionID,
|
||||||
roundNumber int,
|
roundNumber int,
|
||||||
) ([]*entities.SessionMessage, error) {
|
) ([]*entities.SessionMessage, error) {
|
||||||
rows, err := r.db.QueryContext(ctx, `
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
||||||
FROM mpc_messages
|
FROM mpc_messages
|
||||||
WHERE session_id = $1 AND round_number = $2
|
WHERE session_id = $1 AND round_number = $2
|
||||||
ORDER BY created_at ASC
|
ORDER BY created_at ASC
|
||||||
`, sessionID.UUID(), roundNumber)
|
`, sessionID.UUID(), roundNumber)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
return r.scanMessages(rows)
|
return r.scanMessages(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkDelivered marks a message as delivered
|
// MarkDelivered marks a message as delivered
|
||||||
func (r *MessagePostgresRepo) MarkDelivered(ctx context.Context, messageID uuid.UUID) error {
|
func (r *MessagePostgresRepo) MarkDelivered(ctx context.Context, messageID uuid.UUID) error {
|
||||||
_, err := r.db.ExecContext(ctx, `
|
_, err := r.db.ExecContext(ctx, `
|
||||||
UPDATE mpc_messages SET delivered_at = NOW() WHERE id = $1
|
UPDATE mpc_messages SET delivered_at = NOW() WHERE id = $1
|
||||||
`, messageID)
|
`, messageID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkAllDelivered marks all messages for a party as delivered
|
// MarkAllDelivered marks all messages for a party as delivered
|
||||||
func (r *MessagePostgresRepo) MarkAllDelivered(
|
func (r *MessagePostgresRepo) MarkAllDelivered(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID value_objects.SessionID,
|
sessionID value_objects.SessionID,
|
||||||
partyID value_objects.PartyID,
|
partyID value_objects.PartyID,
|
||||||
) error {
|
) error {
|
||||||
_, err := r.db.ExecContext(ctx, `
|
_, err := r.db.ExecContext(ctx, `
|
||||||
UPDATE mpc_messages SET delivered_at = NOW()
|
UPDATE mpc_messages SET delivered_at = NOW()
|
||||||
WHERE session_id = $1
|
WHERE session_id = $1
|
||||||
AND delivered_at IS NULL
|
AND delivered_at IS NULL
|
||||||
AND (to_parties IS NULL OR $2 = ANY(to_parties))
|
AND (to_parties IS NULL OR $2 = ANY(to_parties))
|
||||||
`, sessionID.UUID(), partyID.String())
|
`, sessionID.UUID(), partyID.String())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteBySession deletes all messages for a session
|
// DeleteBySession deletes all messages for a session
|
||||||
func (r *MessagePostgresRepo) DeleteBySession(ctx context.Context, sessionID value_objects.SessionID) error {
|
func (r *MessagePostgresRepo) DeleteBySession(ctx context.Context, sessionID value_objects.SessionID) error {
|
||||||
_, err := r.db.ExecContext(ctx, `DELETE FROM mpc_messages WHERE session_id = $1`, sessionID.UUID())
|
_, err := r.db.ExecContext(ctx, `DELETE FROM mpc_messages WHERE session_id = $1`, sessionID.UUID())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteOlderThan deletes messages older than a specific time
|
// DeleteOlderThan deletes messages older than a specific time
|
||||||
func (r *MessagePostgresRepo) DeleteOlderThan(ctx context.Context, before time.Time) (int64, error) {
|
func (r *MessagePostgresRepo) DeleteOlderThan(ctx context.Context, before time.Time) (int64, error) {
|
||||||
result, err := r.db.ExecContext(ctx, `DELETE FROM mpc_messages WHERE created_at < $1`, before)
|
result, err := r.db.ExecContext(ctx, `DELETE FROM mpc_messages WHERE created_at < $1`, before)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return result.RowsAffected()
|
return result.RowsAffected()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count returns the total number of messages for a session
|
// Count returns the total number of messages for a session
|
||||||
func (r *MessagePostgresRepo) Count(ctx context.Context, sessionID value_objects.SessionID) (int64, error) {
|
func (r *MessagePostgresRepo) Count(ctx context.Context, sessionID value_objects.SessionID) (int64, error) {
|
||||||
var count int64
|
var count int64
|
||||||
err := r.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM mpc_messages WHERE session_id = $1`, sessionID.UUID()).Scan(&count)
|
err := r.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM mpc_messages WHERE session_id = $1`, sessionID.UUID()).Scan(&count)
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountUndelivered returns the number of undelivered messages for a party
|
// CountUndelivered returns the number of undelivered messages for a party
|
||||||
func (r *MessagePostgresRepo) CountUndelivered(
|
func (r *MessagePostgresRepo) CountUndelivered(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID value_objects.SessionID,
|
sessionID value_objects.SessionID,
|
||||||
partyID value_objects.PartyID,
|
partyID value_objects.PartyID,
|
||||||
) (int64, error) {
|
) (int64, error) {
|
||||||
var count int64
|
var count int64
|
||||||
err := r.db.QueryRowContext(ctx, `
|
err := r.db.QueryRowContext(ctx, `
|
||||||
SELECT COUNT(*) FROM mpc_messages
|
SELECT COUNT(*) FROM mpc_messages
|
||||||
WHERE session_id = $1
|
WHERE session_id = $1
|
||||||
AND delivered_at IS NULL
|
AND delivered_at IS NULL
|
||||||
AND (to_parties IS NULL OR $2 = ANY(to_parties))
|
AND (to_parties IS NULL OR $2 = ANY(to_parties))
|
||||||
`, sessionID.UUID(), partyID.String()).Scan(&count)
|
`, sessionID.UUID(), partyID.String()).Scan(&count)
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper methods
|
// Helper methods
|
||||||
|
|
||||||
func (r *MessagePostgresRepo) scanMessages(rows *sql.Rows) ([]*entities.SessionMessage, error) {
|
func (r *MessagePostgresRepo) scanMessages(rows *sql.Rows) ([]*entities.SessionMessage, error) {
|
||||||
var messages []*entities.SessionMessage
|
var messages []*entities.SessionMessage
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var row messageRow
|
var row messageRow
|
||||||
var toParties []string
|
var toParties []string
|
||||||
|
|
||||||
err := rows.Scan(
|
err := rows.Scan(
|
||||||
&row.ID,
|
&row.ID,
|
||||||
&row.SessionID,
|
&row.SessionID,
|
||||||
&row.FromParty,
|
&row.FromParty,
|
||||||
pq.Array(&toParties),
|
pq.Array(&toParties),
|
||||||
&row.RoundNumber,
|
&row.RoundNumber,
|
||||||
&row.MessageType,
|
&row.MessageType,
|
||||||
&row.Payload,
|
&row.Payload,
|
||||||
&row.CreatedAt,
|
&row.CreatedAt,
|
||||||
&row.DeliveredAt,
|
&row.DeliveredAt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err := r.rowToMessage(row, toParties)
|
msg, err := r.rowToMessage(row, toParties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
messages = append(messages, msg)
|
messages = append(messages, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages, rows.Err()
|
return messages, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *MessagePostgresRepo) rowToMessage(row messageRow, toParties []string) (*entities.SessionMessage, error) {
|
func (r *MessagePostgresRepo) rowToMessage(row messageRow, toParties []string) (*entities.SessionMessage, error) {
|
||||||
fromParty, err := value_objects.NewPartyID(row.FromParty)
|
fromParty, err := value_objects.NewPartyID(row.FromParty)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var toPartiesVO []value_objects.PartyID
|
var toPartiesVO []value_objects.PartyID
|
||||||
for _, p := range toParties {
|
for _, p := range toParties {
|
||||||
partyID, err := value_objects.NewPartyID(p)
|
partyID, err := value_objects.NewPartyID(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
toPartiesVO = append(toPartiesVO, partyID)
|
toPartiesVO = append(toPartiesVO, partyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &entities.SessionMessage{
|
return &entities.SessionMessage{
|
||||||
ID: row.ID,
|
ID: row.ID,
|
||||||
SessionID: value_objects.SessionIDFromUUID(row.SessionID),
|
SessionID: value_objects.SessionIDFromUUID(row.SessionID),
|
||||||
FromParty: fromParty,
|
FromParty: fromParty,
|
||||||
ToParties: toPartiesVO,
|
ToParties: toPartiesVO,
|
||||||
RoundNumber: row.RoundNumber,
|
RoundNumber: row.RoundNumber,
|
||||||
MessageType: row.MessageType,
|
MessageType: row.MessageType,
|
||||||
Payload: row.Payload,
|
Payload: row.Payload,
|
||||||
CreatedAt: row.CreatedAt,
|
CreatedAt: row.CreatedAt,
|
||||||
DeliveredAt: row.DeliveredAt,
|
DeliveredAt: row.DeliveredAt,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type messageRow struct {
|
type messageRow struct {
|
||||||
ID uuid.UUID
|
ID uuid.UUID
|
||||||
SessionID uuid.UUID
|
SessionID uuid.UUID
|
||||||
FromParty string
|
FromParty string
|
||||||
RoundNumber int
|
RoundNumber int
|
||||||
MessageType string
|
MessageType string
|
||||||
Payload []byte
|
Payload []byte
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
DeliveredAt *time.Time
|
DeliveredAt *time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure interface compliance
|
// Ensure interface compliance
|
||||||
var _ repositories.MessageRepository = (*MessagePostgresRepo)(nil)
|
var _ repositories.MessageRepository = (*MessagePostgresRepo)(nil)
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue