fix(android): change address format from Cosmos to EVM and fix balance query

Changes:
- Change address derivation from deriveKavaAddress to deriveEvmAddress
  in TssRepository.kt (3 locations)
- Add AddressUtils.isEvmAddress() and getEvmAddress() helper methods
  to handle both old Cosmos and new EVM address formats
- Fix balance query for old wallets by deriving EVM address from
  public key when needed (MainViewModel.fetchBalanceForShare)
- Add retry logic for optimistic lock conflicts in join_session.go
  to prevent party_index collision during concurrent joins

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
hailin 2026-01-01 07:48:52 -08:00
parent 444b720f8d
commit 136a5ba851
4 changed files with 281 additions and 27 deletions

View File

@ -349,6 +349,77 @@ class TssRepository @Inject constructor(
}
}
/**
* Get session status with full participant list
* Calls account-service API: GET /api/v1/co-managed/sessions/{sessionId}
* This is critical for getting the complete participant list before starting keygen
*/
suspend fun getSessionStatus(sessionId: String): Result<SessionStatusResponse> {
return withContext(Dispatchers.IO) {
try {
val request = okhttp3.Request.Builder()
.url("$accountServiceUrl/api/v1/co-managed/sessions/$sessionId")
.get()
.build()
android.util.Log.d("TssRepository", "Getting session status: $sessionId")
val response = httpClient.newCall(request).execute()
val responseBody = response.body?.string()
?: return@withContext Result.failure(Exception("空响应"))
android.util.Log.d("TssRepository", "Session status response: $responseBody")
if (!response.isSuccessful) {
val errorJson = try {
com.google.gson.JsonParser.parseString(responseBody).asJsonObject
} catch (e: Exception) { null }
val errorMsg = errorJson?.get("message")?.asString
?: errorJson?.get("error")?.asString
?: "HTTP ${response.code}"
return@withContext Result.failure(Exception(errorMsg))
}
// Parse response
val json = com.google.gson.JsonParser.parseString(responseBody).asJsonObject
val status = json.get("status")?.asString ?: ""
val thresholdT = json.get("threshold_t")?.asInt ?: 2
val thresholdN = json.get("threshold_n")?.asInt ?: json.get("total_parties")?.asInt ?: 3
val completedParties = json.get("completed_parties")?.asInt ?: 0
val totalParties = json.get("total_parties")?.asInt ?: thresholdN
val sessionType = json.get("session_type")?.asString ?: ""
// Parse participants array with party_index
val participantsArray = json.get("participants")?.asJsonArray
val participants = mutableListOf<ParticipantStatusInfo>()
participantsArray?.forEach { elem ->
val p = elem.asJsonObject
participants.add(ParticipantStatusInfo(
partyId = p.get("party_id")?.asString ?: "",
partyIndex = p.get("party_index")?.asInt ?: 0,
status = p.get("status")?.asString ?: ""
))
}
android.util.Log.d("TssRepository", "Session status: status=$status, participants=${participants.size}, threshold=$thresholdT-of-$thresholdN")
Result.success(SessionStatusResponse(
sessionId = sessionId,
status = status,
thresholdT = thresholdT,
thresholdN = thresholdN,
completedParties = completedParties,
totalParties = totalParties,
sessionType = sessionType,
participants = participants
))
} catch (e: Exception) {
android.util.Log.e("TssRepository", "Get session status failed", e)
Result.failure(e)
}
}
}
/**
* Validate an invite code and get session info
* Calls account-service API: GET /api/v1/co-managed/sessions/by-invite-code/{inviteCode}
@ -557,14 +628,53 @@ class TssRepository @Inject constructor(
android.util.Log.d("TssRepository", "Executing keygen as joiner: sessionId=$sessionId, partyIndex=$partyIndex")
// Start TSS keygen
// CRITICAL: Fetch complete participant list from backend before starting keygen
// This ensures we have all participants even if some joined after we did
val statusResult = getSessionStatus(sessionId)
if (statusResult.isFailure) {
android.util.Log.e("TssRepository", "Failed to get session status before keygen", statusResult.exceptionOrNull())
return@coroutineScope Result.failure(statusResult.exceptionOrNull()!!)
}
val sessionStatus = statusResult.getOrThrow()
// Build complete participant list from backend response
val updatedParticipants = mutableListOf<Participant>()
for (p in sessionStatus.participants) {
val name = if (p.partyId == partyId) {
session.participants.find { it.partyId == partyId }?.name ?: ""
} else {
session.participants.find { it.partyId == p.partyId }?.name ?: "参与方 ${p.partyIndex + 1}"
}
updatedParticipants.add(Participant(p.partyId, p.partyIndex, name))
}
// Find my party index from the updated list (may differ from initial partyIndex if backend assigned differently)
val myInfo = updatedParticipants.find { it.partyId == partyId }
val actualPartyIndex = myInfo?.partyIndex ?: partyIndex
android.util.Log.d("TssRepository", "Joiner updated participants: ${updatedParticipants.map { "${it.partyId.take(8)}:${it.partyIndex}" }}")
android.util.Log.d("TssRepository", "My actual party index: $actualPartyIndex (original: $partyIndex)")
// Update the session with complete participant list
_currentSession.value = session.copy(
participants = updatedParticipants,
thresholdT = sessionStatus.thresholdT,
thresholdN = sessionStatus.thresholdN
)
// Use thresholds from backend (they may be more accurate)
val actualThresholdT = sessionStatus.thresholdT
val actualThresholdN = sessionStatus.thresholdN
// Start TSS keygen with complete participant list
val startResult = tssNativeBridge.startKeygen(
sessionId = sessionId,
partyId = partyId,
partyIndex = partyIndex,
thresholdT = thresholdT,
thresholdN = thresholdN,
participants = session.participants,
partyIndex = actualPartyIndex,
thresholdT = actualThresholdT,
thresholdN = actualThresholdN,
participants = updatedParticipants,
password = password
)
@ -588,16 +698,16 @@ class TssRepository @Inject constructor(
// Derive address from public key
val publicKeyBytes = Base64.decode(result.publicKey, Base64.NO_WRAP)
val address = AddressUtils.deriveKavaAddress(publicKeyBytes)
val address = AddressUtils.deriveEvmAddress(publicKeyBytes)
// Save share record
// Save share record (use actual thresholds and party index from backend)
val shareEntity = ShareRecordEntity(
sessionId = sessionId,
publicKey = result.publicKey,
encryptedShare = result.encryptedShare,
thresholdT = thresholdT,
thresholdN = thresholdN,
partyIndex = partyIndex,
thresholdT = actualThresholdT,
thresholdN = actualThresholdN,
partyIndex = actualPartyIndex,
address = address
)
val id = shareRecordDao.insertShare(shareEntity)
@ -607,7 +717,7 @@ class TssRepository @Inject constructor(
_sessionStatus.value = SessionStatus.COMPLETED
android.util.Log.d("TssRepository", "Keygen as joiner completed: address=$address")
android.util.Log.d("TssRepository", "Keygen as joiner completed: address=$address, partyIndex=$actualPartyIndex")
Result.success(shareEntity.copy(id = id).toShareRecord())
@ -872,7 +982,7 @@ class TssRepository @Inject constructor(
// Derive address from public key
val publicKeyBytes = Base64.decode(result.publicKey, Base64.NO_WRAP)
val address = AddressUtils.deriveKavaAddress(publicKeyBytes)
val address = AddressUtils.deriveEvmAddress(publicKeyBytes)
// Save share record
val shareEntity = ShareRecordEntity(
@ -1194,6 +1304,13 @@ class TssRepository @Inject constructor(
/**
* Start keygen as initiator (called when session_started event is received)
*
* IMPORTANT: Before starting keygen, we fetch the complete participant list from
* the backend. This is critical because the local session only contains ourselves
* as a participant initially. Without the full list, TSS keygen will fail.
*
* This matches Electron's behavior in checkAndTriggerKeygen() which calls
* getSessionStatus() to get all participants before starting keygen.
*/
suspend fun startKeygenAsInitiator(
sessionId: String,
@ -1207,18 +1324,56 @@ class TssRepository @Inject constructor(
return@coroutineScope Result.failure(Exception("No active session"))
}
android.util.Log.d("TssRepository", "Starting keygen as initiator: sessionId=$sessionId, partyIndex=${session.participants.firstOrNull()?.partyIndex}")
android.util.Log.d("TssRepository", "Starting keygen as initiator: sessionId=$sessionId")
val myPartyIndex = session.participants.firstOrNull()?.partyIndex ?: 0
// CRITICAL: Fetch complete participant list from backend before starting keygen
// This matches Electron's checkAndTriggerKeygen() behavior
val statusResult = getSessionStatus(sessionId)
if (statusResult.isFailure) {
android.util.Log.e("TssRepository", "Failed to get session status: ${statusResult.exceptionOrNull()?.message}")
return@coroutineScope Result.failure(statusResult.exceptionOrNull()!!)
}
// Start TSS keygen
val sessionStatus = statusResult.getOrThrow()
android.util.Log.d("TssRepository", "Got session status: ${sessionStatus.participants.size} participants")
// Build complete participant list from backend response
val updatedParticipants = mutableListOf<Participant>()
for (p in sessionStatus.participants) {
val name = if (p.partyId == partyId) {
session.participants.firstOrNull()?.name ?: ""
} else {
"参与方 ${p.partyIndex + 1}"
}
updatedParticipants.add(Participant(p.partyId, p.partyIndex, name))
}
// Find my party index from the updated list
val myInfo = updatedParticipants.find { it.partyId == partyId }
val myPartyIndex = myInfo?.partyIndex ?: session.participants.firstOrNull()?.partyIndex ?: 0
android.util.Log.d("TssRepository", "Updated participants: ${updatedParticipants.map { "${it.partyId.take(8)}:${it.partyIndex}" }}")
android.util.Log.d("TssRepository", "My party index: $myPartyIndex")
// Update the session with complete participant list
_currentSession.value = session.copy(
participants = updatedParticipants,
thresholdT = sessionStatus.thresholdT,
thresholdN = sessionStatus.thresholdN
)
// Use thresholds from backend (they may be more accurate)
val actualThresholdT = sessionStatus.thresholdT
val actualThresholdN = sessionStatus.thresholdN
// Start TSS keygen with complete participant list
val startResult = tssNativeBridge.startKeygen(
sessionId = sessionId,
partyId = partyId,
partyIndex = myPartyIndex,
thresholdT = thresholdT,
thresholdN = thresholdN,
participants = session.participants,
thresholdT = actualThresholdT,
thresholdN = actualThresholdN,
participants = updatedParticipants,
password = password
)
@ -1242,15 +1397,15 @@ class TssRepository @Inject constructor(
// Derive address from public key
val publicKeyBytes = Base64.decode(result.publicKey, Base64.NO_WRAP)
val address = AddressUtils.deriveKavaAddress(publicKeyBytes)
val address = AddressUtils.deriveEvmAddress(publicKeyBytes)
// Save share record
// Save share record (use actual thresholds from backend)
val shareEntity = ShareRecordEntity(
sessionId = sessionId,
publicKey = result.publicKey,
encryptedShare = result.encryptedShare,
thresholdT = thresholdT,
thresholdN = thresholdN,
thresholdT = actualThresholdT,
thresholdN = actualThresholdN,
partyIndex = myPartyIndex,
address = address
)
@ -1828,6 +1983,30 @@ data class JoinSignViaGrpcResult(
val shareId: Long
)
/**
* Session status response from getSessionStatus API
* Matches Electron's GetSessionStatusResponse
*/
data class SessionStatusResponse(
val sessionId: String,
val status: String,
val thresholdT: Int,
val thresholdN: Int,
val completedParties: Int,
val totalParties: Int,
val sessionType: String,
val participants: List<ParticipantStatusInfo>
)
/**
* Participant status info with party_index
*/
data class ParticipantStatusInfo(
val partyId: String,
val partyIndex: Int,
val status: String
)
private fun ShareRecordEntity.toShareRecord() = ShareRecord(
id = id,
sessionId = sessionId,

View File

@ -5,6 +5,7 @@ import androidx.lifecycle.viewModelScope
import com.durian.tssparty.data.repository.JoinKeygenViaGrpcResult
import com.durian.tssparty.data.repository.TssRepository
import com.durian.tssparty.domain.model.*
import com.durian.tssparty.util.AddressUtils
import com.durian.tssparty.util.TransactionUtils
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.delay
@ -898,7 +899,23 @@ class MainViewModel @Inject constructor(
val balances: StateFlow<Map<String, String>> = _balances.asStateFlow()
/**
* Fetch balance for a wallet address
* Fetch balance for a wallet using share data (handles both EVM and Cosmos address formats)
*/
fun fetchBalanceForShare(share: ShareRecord) {
viewModelScope.launch {
val rpcUrl = _settings.value.kavaRpcUrl
// Ensure we use EVM address format for RPC calls
val evmAddress = AddressUtils.getEvmAddress(share.address, share.publicKey)
val result = repository.getBalance(evmAddress, rpcUrl)
result.onSuccess { balance ->
// Store balance with original address as key (for UI lookup)
_balances.update { it + (share.address to balance) }
}
}
}
/**
* Fetch balance for a wallet address (for already-EVM addresses)
*/
fun fetchBalance(address: String) {
viewModelScope.launch {
@ -916,7 +933,7 @@ class MainViewModel @Inject constructor(
fun fetchAllBalances() {
viewModelScope.launch {
shares.value.forEach { share ->
fetchBalance(share.address)
fetchBalanceForShare(share)
}
}
}

View File

@ -30,6 +30,27 @@ object AddressUtils {
return Bech32.encode("kava", convertBits(ripemd160, 8, 5, true))
}
/**
* Check if address is in EVM format (0x...)
*/
fun isEvmAddress(address: String): Boolean {
return address.startsWith("0x") && address.length == 42
}
/**
* Get EVM address - either returns the address if already EVM format,
* or derives it from the public key
*/
fun getEvmAddress(address: String, publicKeyBase64: String): String {
return if (isEvmAddress(address)) {
address
} else {
// Derive EVM address from public key
val publicKeyBytes = android.util.Base64.decode(publicKeyBase64, android.util.Base64.NO_WRAP)
deriveEvmAddress(publicKeyBytes)
}
}
/**
* Derive EVM address from public key (for Kava EVM compatibility)
*/

View File

@ -2,6 +2,8 @@ package use_cases
import (
"context"
"errors"
"fmt"
"time"
"github.com/google/uuid"
@ -16,6 +18,9 @@ import (
"go.uber.org/zap"
)
// Maximum retries for optimistic lock conflicts during join session
const joinSessionMaxRetries = 3
// JoinSessionMessageRouterClient defines the interface for publishing session events via gRPC
type JoinSessionMessageRouterClient interface {
PublishSessionStarted(
@ -54,11 +59,35 @@ func NewJoinSessionUseCase(
}
}
// Execute executes the join session use case
// Execute executes the join session use case with retry logic for optimistic lock conflicts
func (uc *JoinSessionUseCase) Execute(
ctx context.Context,
inputData input.JoinSessionInput,
) (*input.JoinSessionOutput, error) {
return uc.executeWithRetry(ctx, inputData, 0)
}
// executeWithRetry executes the join session with retry logic for optimistic lock conflicts
func (uc *JoinSessionUseCase) executeWithRetry(
ctx context.Context,
inputData input.JoinSessionInput,
retry int,
) (*input.JoinSessionOutput, error) {
if retry >= joinSessionMaxRetries {
logger.Error("max retries exceeded for optimistic lock in join session",
zap.String("session_id", inputData.SessionID.String()),
zap.String("party_id", inputData.PartyID),
zap.Int("retry_count", retry))
return nil, fmt.Errorf("max retries exceeded: %w", entities.ErrOptimisticLockConflict)
}
if retry > 0 {
logger.Info("retrying join session due to optimistic lock conflict",
zap.String("session_id", inputData.SessionID.String()),
zap.String("party_id", inputData.PartyID),
zap.Int("retry_attempt", retry))
}
// Debug: log token info
tokenLen := len(inputData.JoinToken)
tokenPreview := ""
@ -102,7 +131,7 @@ func (uc *JoinSessionUseCase) Execute(
return nil, err
}
// 3. Load session
// 3. Load session (fresh read for each retry attempt)
session, err := uc.sessionRepo.FindByUUID(ctx, sessionID)
if err != nil {
return nil, err
@ -229,8 +258,16 @@ func (uc *JoinSessionUseCase) Execute(
}
}
// 8. Save updated session
// 8. Save updated session (with optimistic lock retry)
if err := uc.sessionRepo.Update(ctx, session); err != nil {
// Check if this is an optimistic lock conflict - if so, retry
if errors.Is(err, entities.ErrOptimisticLockConflict) {
logger.Warn("optimistic lock conflict detected in join session, retrying",
zap.String("session_id", session.ID.String()),
zap.String("party_id", inputData.PartyID),
zap.Int("retry_attempt", retry+1))
return uc.executeWithRetry(ctx, inputData, retry+1)
}
return nil, err
}