From 136a5ba8515f732f790ee01f8a6c165afd0391d7 Mon Sep 17 00:00:00 2001 From: hailin Date: Thu, 1 Jan 2026 07:48:52 -0800 Subject: [PATCH] fix(android): change address format from Cosmos to EVM and fix balance query MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes: - Change address derivation from deriveKavaAddress to deriveEvmAddress in TssRepository.kt (3 locations) - Add AddressUtils.isEvmAddress() and getEvmAddress() helper methods to handle both old Cosmos and new EVM address formats - Fix balance query for old wallets by deriving EVM address from public key when needed (MainViewModel.fetchBalanceForShare) - Add retry logic for optimistic lock conflicts in join_session.go to prevent party_index collision during concurrent joins 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../tssparty/data/repository/TssRepository.kt | 223 ++++++++++++++++-- .../presentation/viewmodel/MainViewModel.kt | 21 +- .../com/durian/tssparty/util/AddressUtils.kt | 21 ++ .../application/use_cases/join_session.go | 43 +++- 4 files changed, 281 insertions(+), 27 deletions(-) diff --git a/backend/mpc-system/services/service-party-android/app/src/main/java/com/durian/tssparty/data/repository/TssRepository.kt b/backend/mpc-system/services/service-party-android/app/src/main/java/com/durian/tssparty/data/repository/TssRepository.kt index 77151808..e13d0796 100644 --- a/backend/mpc-system/services/service-party-android/app/src/main/java/com/durian/tssparty/data/repository/TssRepository.kt +++ b/backend/mpc-system/services/service-party-android/app/src/main/java/com/durian/tssparty/data/repository/TssRepository.kt @@ -349,6 +349,77 @@ class TssRepository @Inject constructor( } } + /** + * Get session status with full participant list + * Calls account-service API: GET /api/v1/co-managed/sessions/{sessionId} + * This is critical for getting the complete participant list before starting keygen + */ + suspend fun getSessionStatus(sessionId: String): Result { + return withContext(Dispatchers.IO) { + try { + val request = okhttp3.Request.Builder() + .url("$accountServiceUrl/api/v1/co-managed/sessions/$sessionId") + .get() + .build() + + android.util.Log.d("TssRepository", "Getting session status: $sessionId") + + val response = httpClient.newCall(request).execute() + val responseBody = response.body?.string() + ?: return@withContext Result.failure(Exception("空响应")) + + android.util.Log.d("TssRepository", "Session status response: $responseBody") + + if (!response.isSuccessful) { + val errorJson = try { + com.google.gson.JsonParser.parseString(responseBody).asJsonObject + } catch (e: Exception) { null } + val errorMsg = errorJson?.get("message")?.asString + ?: errorJson?.get("error")?.asString + ?: "HTTP ${response.code}" + return@withContext Result.failure(Exception(errorMsg)) + } + + // Parse response + val json = com.google.gson.JsonParser.parseString(responseBody).asJsonObject + val status = json.get("status")?.asString ?: "" + val thresholdT = json.get("threshold_t")?.asInt ?: 2 + val thresholdN = json.get("threshold_n")?.asInt ?: json.get("total_parties")?.asInt ?: 3 + val completedParties = json.get("completed_parties")?.asInt ?: 0 + val totalParties = json.get("total_parties")?.asInt ?: thresholdN + val sessionType = json.get("session_type")?.asString ?: "" + + // Parse participants array with party_index + val participantsArray = json.get("participants")?.asJsonArray + val participants = mutableListOf() + participantsArray?.forEach { elem -> + val p = elem.asJsonObject + participants.add(ParticipantStatusInfo( + partyId = p.get("party_id")?.asString ?: "", + partyIndex = p.get("party_index")?.asInt ?: 0, + status = p.get("status")?.asString ?: "" + )) + } + + android.util.Log.d("TssRepository", "Session status: status=$status, participants=${participants.size}, threshold=$thresholdT-of-$thresholdN") + + Result.success(SessionStatusResponse( + sessionId = sessionId, + status = status, + thresholdT = thresholdT, + thresholdN = thresholdN, + completedParties = completedParties, + totalParties = totalParties, + sessionType = sessionType, + participants = participants + )) + } catch (e: Exception) { + android.util.Log.e("TssRepository", "Get session status failed", e) + Result.failure(e) + } + } + } + /** * Validate an invite code and get session info * Calls account-service API: GET /api/v1/co-managed/sessions/by-invite-code/{inviteCode} @@ -557,14 +628,53 @@ class TssRepository @Inject constructor( android.util.Log.d("TssRepository", "Executing keygen as joiner: sessionId=$sessionId, partyIndex=$partyIndex") - // Start TSS keygen + // CRITICAL: Fetch complete participant list from backend before starting keygen + // This ensures we have all participants even if some joined after we did + val statusResult = getSessionStatus(sessionId) + if (statusResult.isFailure) { + android.util.Log.e("TssRepository", "Failed to get session status before keygen", statusResult.exceptionOrNull()) + return@coroutineScope Result.failure(statusResult.exceptionOrNull()!!) + } + + val sessionStatus = statusResult.getOrThrow() + + // Build complete participant list from backend response + val updatedParticipants = mutableListOf() + for (p in sessionStatus.participants) { + val name = if (p.partyId == partyId) { + session.participants.find { it.partyId == partyId }?.name ?: "我" + } else { + session.participants.find { it.partyId == p.partyId }?.name ?: "参与方 ${p.partyIndex + 1}" + } + updatedParticipants.add(Participant(p.partyId, p.partyIndex, name)) + } + + // Find my party index from the updated list (may differ from initial partyIndex if backend assigned differently) + val myInfo = updatedParticipants.find { it.partyId == partyId } + val actualPartyIndex = myInfo?.partyIndex ?: partyIndex + + android.util.Log.d("TssRepository", "Joiner updated participants: ${updatedParticipants.map { "${it.partyId.take(8)}:${it.partyIndex}" }}") + android.util.Log.d("TssRepository", "My actual party index: $actualPartyIndex (original: $partyIndex)") + + // Update the session with complete participant list + _currentSession.value = session.copy( + participants = updatedParticipants, + thresholdT = sessionStatus.thresholdT, + thresholdN = sessionStatus.thresholdN + ) + + // Use thresholds from backend (they may be more accurate) + val actualThresholdT = sessionStatus.thresholdT + val actualThresholdN = sessionStatus.thresholdN + + // Start TSS keygen with complete participant list val startResult = tssNativeBridge.startKeygen( sessionId = sessionId, partyId = partyId, - partyIndex = partyIndex, - thresholdT = thresholdT, - thresholdN = thresholdN, - participants = session.participants, + partyIndex = actualPartyIndex, + thresholdT = actualThresholdT, + thresholdN = actualThresholdN, + participants = updatedParticipants, password = password ) @@ -588,16 +698,16 @@ class TssRepository @Inject constructor( // Derive address from public key val publicKeyBytes = Base64.decode(result.publicKey, Base64.NO_WRAP) - val address = AddressUtils.deriveKavaAddress(publicKeyBytes) + val address = AddressUtils.deriveEvmAddress(publicKeyBytes) - // Save share record + // Save share record (use actual thresholds and party index from backend) val shareEntity = ShareRecordEntity( sessionId = sessionId, publicKey = result.publicKey, encryptedShare = result.encryptedShare, - thresholdT = thresholdT, - thresholdN = thresholdN, - partyIndex = partyIndex, + thresholdT = actualThresholdT, + thresholdN = actualThresholdN, + partyIndex = actualPartyIndex, address = address ) val id = shareRecordDao.insertShare(shareEntity) @@ -607,7 +717,7 @@ class TssRepository @Inject constructor( _sessionStatus.value = SessionStatus.COMPLETED - android.util.Log.d("TssRepository", "Keygen as joiner completed: address=$address") + android.util.Log.d("TssRepository", "Keygen as joiner completed: address=$address, partyIndex=$actualPartyIndex") Result.success(shareEntity.copy(id = id).toShareRecord()) @@ -872,7 +982,7 @@ class TssRepository @Inject constructor( // Derive address from public key val publicKeyBytes = Base64.decode(result.publicKey, Base64.NO_WRAP) - val address = AddressUtils.deriveKavaAddress(publicKeyBytes) + val address = AddressUtils.deriveEvmAddress(publicKeyBytes) // Save share record val shareEntity = ShareRecordEntity( @@ -1194,6 +1304,13 @@ class TssRepository @Inject constructor( /** * Start keygen as initiator (called when session_started event is received) + * + * IMPORTANT: Before starting keygen, we fetch the complete participant list from + * the backend. This is critical because the local session only contains ourselves + * as a participant initially. Without the full list, TSS keygen will fail. + * + * This matches Electron's behavior in checkAndTriggerKeygen() which calls + * getSessionStatus() to get all participants before starting keygen. */ suspend fun startKeygenAsInitiator( sessionId: String, @@ -1207,18 +1324,56 @@ class TssRepository @Inject constructor( return@coroutineScope Result.failure(Exception("No active session")) } - android.util.Log.d("TssRepository", "Starting keygen as initiator: sessionId=$sessionId, partyIndex=${session.participants.firstOrNull()?.partyIndex}") + android.util.Log.d("TssRepository", "Starting keygen as initiator: sessionId=$sessionId") - val myPartyIndex = session.participants.firstOrNull()?.partyIndex ?: 0 + // CRITICAL: Fetch complete participant list from backend before starting keygen + // This matches Electron's checkAndTriggerKeygen() behavior + val statusResult = getSessionStatus(sessionId) + if (statusResult.isFailure) { + android.util.Log.e("TssRepository", "Failed to get session status: ${statusResult.exceptionOrNull()?.message}") + return@coroutineScope Result.failure(statusResult.exceptionOrNull()!!) + } - // Start TSS keygen + val sessionStatus = statusResult.getOrThrow() + android.util.Log.d("TssRepository", "Got session status: ${sessionStatus.participants.size} participants") + + // Build complete participant list from backend response + val updatedParticipants = mutableListOf() + for (p in sessionStatus.participants) { + val name = if (p.partyId == partyId) { + session.participants.firstOrNull()?.name ?: "我" + } else { + "参与方 ${p.partyIndex + 1}" + } + updatedParticipants.add(Participant(p.partyId, p.partyIndex, name)) + } + + // Find my party index from the updated list + val myInfo = updatedParticipants.find { it.partyId == partyId } + val myPartyIndex = myInfo?.partyIndex ?: session.participants.firstOrNull()?.partyIndex ?: 0 + + android.util.Log.d("TssRepository", "Updated participants: ${updatedParticipants.map { "${it.partyId.take(8)}:${it.partyIndex}" }}") + android.util.Log.d("TssRepository", "My party index: $myPartyIndex") + + // Update the session with complete participant list + _currentSession.value = session.copy( + participants = updatedParticipants, + thresholdT = sessionStatus.thresholdT, + thresholdN = sessionStatus.thresholdN + ) + + // Use thresholds from backend (they may be more accurate) + val actualThresholdT = sessionStatus.thresholdT + val actualThresholdN = sessionStatus.thresholdN + + // Start TSS keygen with complete participant list val startResult = tssNativeBridge.startKeygen( sessionId = sessionId, partyId = partyId, partyIndex = myPartyIndex, - thresholdT = thresholdT, - thresholdN = thresholdN, - participants = session.participants, + thresholdT = actualThresholdT, + thresholdN = actualThresholdN, + participants = updatedParticipants, password = password ) @@ -1242,15 +1397,15 @@ class TssRepository @Inject constructor( // Derive address from public key val publicKeyBytes = Base64.decode(result.publicKey, Base64.NO_WRAP) - val address = AddressUtils.deriveKavaAddress(publicKeyBytes) + val address = AddressUtils.deriveEvmAddress(publicKeyBytes) - // Save share record + // Save share record (use actual thresholds from backend) val shareEntity = ShareRecordEntity( sessionId = sessionId, publicKey = result.publicKey, encryptedShare = result.encryptedShare, - thresholdT = thresholdT, - thresholdN = thresholdN, + thresholdT = actualThresholdT, + thresholdN = actualThresholdN, partyIndex = myPartyIndex, address = address ) @@ -1828,6 +1983,30 @@ data class JoinSignViaGrpcResult( val shareId: Long ) +/** + * Session status response from getSessionStatus API + * Matches Electron's GetSessionStatusResponse + */ +data class SessionStatusResponse( + val sessionId: String, + val status: String, + val thresholdT: Int, + val thresholdN: Int, + val completedParties: Int, + val totalParties: Int, + val sessionType: String, + val participants: List +) + +/** + * Participant status info with party_index + */ +data class ParticipantStatusInfo( + val partyId: String, + val partyIndex: Int, + val status: String +) + private fun ShareRecordEntity.toShareRecord() = ShareRecord( id = id, sessionId = sessionId, diff --git a/backend/mpc-system/services/service-party-android/app/src/main/java/com/durian/tssparty/presentation/viewmodel/MainViewModel.kt b/backend/mpc-system/services/service-party-android/app/src/main/java/com/durian/tssparty/presentation/viewmodel/MainViewModel.kt index e62b2882..0924ee63 100644 --- a/backend/mpc-system/services/service-party-android/app/src/main/java/com/durian/tssparty/presentation/viewmodel/MainViewModel.kt +++ b/backend/mpc-system/services/service-party-android/app/src/main/java/com/durian/tssparty/presentation/viewmodel/MainViewModel.kt @@ -5,6 +5,7 @@ import androidx.lifecycle.viewModelScope import com.durian.tssparty.data.repository.JoinKeygenViaGrpcResult import com.durian.tssparty.data.repository.TssRepository import com.durian.tssparty.domain.model.* +import com.durian.tssparty.util.AddressUtils import com.durian.tssparty.util.TransactionUtils import dagger.hilt.android.lifecycle.HiltViewModel import kotlinx.coroutines.delay @@ -898,7 +899,23 @@ class MainViewModel @Inject constructor( val balances: StateFlow> = _balances.asStateFlow() /** - * Fetch balance for a wallet address + * Fetch balance for a wallet using share data (handles both EVM and Cosmos address formats) + */ + fun fetchBalanceForShare(share: ShareRecord) { + viewModelScope.launch { + val rpcUrl = _settings.value.kavaRpcUrl + // Ensure we use EVM address format for RPC calls + val evmAddress = AddressUtils.getEvmAddress(share.address, share.publicKey) + val result = repository.getBalance(evmAddress, rpcUrl) + result.onSuccess { balance -> + // Store balance with original address as key (for UI lookup) + _balances.update { it + (share.address to balance) } + } + } + } + + /** + * Fetch balance for a wallet address (for already-EVM addresses) */ fun fetchBalance(address: String) { viewModelScope.launch { @@ -916,7 +933,7 @@ class MainViewModel @Inject constructor( fun fetchAllBalances() { viewModelScope.launch { shares.value.forEach { share -> - fetchBalance(share.address) + fetchBalanceForShare(share) } } } diff --git a/backend/mpc-system/services/service-party-android/app/src/main/java/com/durian/tssparty/util/AddressUtils.kt b/backend/mpc-system/services/service-party-android/app/src/main/java/com/durian/tssparty/util/AddressUtils.kt index a25ef8c6..10183acd 100644 --- a/backend/mpc-system/services/service-party-android/app/src/main/java/com/durian/tssparty/util/AddressUtils.kt +++ b/backend/mpc-system/services/service-party-android/app/src/main/java/com/durian/tssparty/util/AddressUtils.kt @@ -30,6 +30,27 @@ object AddressUtils { return Bech32.encode("kava", convertBits(ripemd160, 8, 5, true)) } + /** + * Check if address is in EVM format (0x...) + */ + fun isEvmAddress(address: String): Boolean { + return address.startsWith("0x") && address.length == 42 + } + + /** + * Get EVM address - either returns the address if already EVM format, + * or derives it from the public key + */ + fun getEvmAddress(address: String, publicKeyBase64: String): String { + return if (isEvmAddress(address)) { + address + } else { + // Derive EVM address from public key + val publicKeyBytes = android.util.Base64.decode(publicKeyBase64, android.util.Base64.NO_WRAP) + deriveEvmAddress(publicKeyBytes) + } + } + /** * Derive EVM address from public key (for Kava EVM compatibility) */ diff --git a/backend/mpc-system/services/session-coordinator/application/use_cases/join_session.go b/backend/mpc-system/services/session-coordinator/application/use_cases/join_session.go index 808a520d..e1423953 100644 --- a/backend/mpc-system/services/session-coordinator/application/use_cases/join_session.go +++ b/backend/mpc-system/services/session-coordinator/application/use_cases/join_session.go @@ -2,6 +2,8 @@ package use_cases import ( "context" + "errors" + "fmt" "time" "github.com/google/uuid" @@ -16,6 +18,9 @@ import ( "go.uber.org/zap" ) +// Maximum retries for optimistic lock conflicts during join session +const joinSessionMaxRetries = 3 + // JoinSessionMessageRouterClient defines the interface for publishing session events via gRPC type JoinSessionMessageRouterClient interface { PublishSessionStarted( @@ -54,11 +59,35 @@ func NewJoinSessionUseCase( } } -// Execute executes the join session use case +// Execute executes the join session use case with retry logic for optimistic lock conflicts func (uc *JoinSessionUseCase) Execute( ctx context.Context, inputData input.JoinSessionInput, ) (*input.JoinSessionOutput, error) { + return uc.executeWithRetry(ctx, inputData, 0) +} + +// executeWithRetry executes the join session with retry logic for optimistic lock conflicts +func (uc *JoinSessionUseCase) executeWithRetry( + ctx context.Context, + inputData input.JoinSessionInput, + retry int, +) (*input.JoinSessionOutput, error) { + if retry >= joinSessionMaxRetries { + logger.Error("max retries exceeded for optimistic lock in join session", + zap.String("session_id", inputData.SessionID.String()), + zap.String("party_id", inputData.PartyID), + zap.Int("retry_count", retry)) + return nil, fmt.Errorf("max retries exceeded: %w", entities.ErrOptimisticLockConflict) + } + + if retry > 0 { + logger.Info("retrying join session due to optimistic lock conflict", + zap.String("session_id", inputData.SessionID.String()), + zap.String("party_id", inputData.PartyID), + zap.Int("retry_attempt", retry)) + } + // Debug: log token info tokenLen := len(inputData.JoinToken) tokenPreview := "" @@ -102,7 +131,7 @@ func (uc *JoinSessionUseCase) Execute( return nil, err } - // 3. Load session + // 3. Load session (fresh read for each retry attempt) session, err := uc.sessionRepo.FindByUUID(ctx, sessionID) if err != nil { return nil, err @@ -229,8 +258,16 @@ func (uc *JoinSessionUseCase) Execute( } } - // 8. Save updated session + // 8. Save updated session (with optimistic lock retry) if err := uc.sessionRepo.Update(ctx, session); err != nil { + // Check if this is an optimistic lock conflict - if so, retry + if errors.Is(err, entities.ErrOptimisticLockConflict) { + logger.Warn("optimistic lock conflict detected in join session, retrying", + zap.String("session_id", session.ID.String()), + zap.String("party_id", inputData.PartyID), + zap.Int("retry_attempt", retry+1)) + return uc.executeWithRetry(ctx, inputData, retry+1) + } return nil, err }