diff --git a/app/src/main/java/com/notcvnt/rknhardering/ScanExecutionContext.kt b/app/src/main/java/com/notcvnt/rknhardering/ScanExecutionContext.kt new file mode 100644 index 0000000..72c95ea --- /dev/null +++ b/app/src/main/java/com/notcvnt/rknhardering/ScanExecutionContext.kt @@ -0,0 +1,104 @@ +package com.notcvnt.rknhardering + +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.asContextElement +import java.net.DatagramSocket +import java.net.Socket +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicLong +import kotlin.coroutines.CoroutineContext + +private val scanExecutionContextThreadLocal = ThreadLocal() + +data class ScanExecutionContext( + val scanId: Long = 0L, + val cancellationSignal: ScanCancellationSignal = ScanCancellationSignal(), +) { + fun asCoroutineContext(): CoroutineContext = scanExecutionContextThreadLocal.asContextElement(this) + + fun throwIfCancelled(cause: Throwable? = null) { + cancellationSignal.throwIfCancelled(cause) + } + + companion object { + fun currentOrDefault(): ScanExecutionContext { + return scanExecutionContextThreadLocal.get() ?: ScanExecutionContext() + } + } +} + +class ScanCancellationSignal { + private val cancelled = AtomicBoolean(false) + private val nextRegistrationId = AtomicLong(1L) + private val callbacks = ConcurrentHashMap Unit>() + + fun isCancelled(): Boolean = cancelled.get() + + fun throwIfCancelled(cause: Throwable? = null) { + if (!isCancelled()) return + throw CancellationException("Scan cancelled").also { cancellation -> + cause?.let(cancellation::initCause) + } + } + + fun register(callback: () -> Unit): Registration { + if (isCancelled()) { + callback() + return Registration.NO_OP + } + + val id = nextRegistrationId.getAndIncrement() + callbacks[id] = callback + if (isCancelled()) { + callbacks.remove(id)?.invoke() + return Registration.NO_OP + } + + return Registration(this, id) + } + + fun cancel() { + if (!cancelled.compareAndSet(false, true)) return + val pending = callbacks.entries.toList() + callbacks.clear() + pending.forEach { (_, callback) -> + runCatching(callback) + } + } + + private fun unregister(id: Long) { + callbacks.remove(id) + } + + class Registration internal constructor( + private val signal: ScanCancellationSignal?, + private val id: Long?, + ) { + fun dispose() { + val activeSignal = signal ?: return + val activeId = id ?: return + activeSignal.unregister(activeId) + } + + companion object { + internal val NO_OP = Registration(signal = null, id = null) + } + } +} + +fun rethrowIfCancellation( + error: Throwable, + executionContext: ScanExecutionContext = ScanExecutionContext.currentOrDefault(), +) { + if (error is CancellationException) throw error + executionContext.throwIfCancelled(error) +} + +fun ScanCancellationSignal.registerSocket(socket: Socket): ScanCancellationSignal.Registration { + return register { runCatching { socket.close() } } +} + +fun ScanCancellationSignal.registerDatagramSocket(socket: DatagramSocket): ScanCancellationSignal.Registration { + return register { runCatching { socket.close() } } +} diff --git a/app/src/main/java/com/notcvnt/rknhardering/checker/IpComparisonChecker.kt b/app/src/main/java/com/notcvnt/rknhardering/checker/IpComparisonChecker.kt index 675c6e4..5ef965b 100644 --- a/app/src/main/java/com/notcvnt/rknhardering/checker/IpComparisonChecker.kt +++ b/app/src/main/java/com/notcvnt/rknhardering/checker/IpComparisonChecker.kt @@ -65,13 +65,13 @@ object IpComparisonChecker { ), EndpointSpec( label = "ifconfig.me IPv4", - url = "https://ipv4.ifconfig.me/ip", + url = "https://ifconfig.me/ip", scope = IpCheckerScope.NON_RU, addressFamily = Inet4Address::class.java, ), EndpointSpec( label = "ifconfig.me IPv6", - url = "https://ipv6.ifconfig.me/ip", + url = "https://ifconfig.me/ip", scope = IpCheckerScope.NON_RU, addressFamily = Inet6Address::class.java, ), diff --git a/app/src/main/java/com/notcvnt/rknhardering/model/CheckResult.kt b/app/src/main/java/com/notcvnt/rknhardering/model/CheckResult.kt index 470a68e..2686fdb 100644 --- a/app/src/main/java/com/notcvnt/rknhardering/model/CheckResult.kt +++ b/app/src/main/java/com/notcvnt/rknhardering/model/CheckResult.kt @@ -231,6 +231,9 @@ data class CdnPullingResponse( val targetLabel: String, val url: String, val ip: String? = null, + val ipv4: String? = null, + val ipv6: String? = null, + val ipv4Unavailable: Boolean = false, val importantFields: Map = emptyMap(), val rawBody: String? = null, val error: String? = null, diff --git a/app/src/main/java/com/notcvnt/rknhardering/probe/PublicIpClient.kt b/app/src/main/java/com/notcvnt/rknhardering/probe/PublicIpClient.kt index f320328..94fe3cc 100644 --- a/app/src/main/java/com/notcvnt/rknhardering/probe/PublicIpClient.kt +++ b/app/src/main/java/com/notcvnt/rknhardering/probe/PublicIpClient.kt @@ -196,6 +196,8 @@ object PublicIpClient { config = request.config, proxy = request.proxy, binding = request.binding, + addressFamily = request.addressFamily, + cancellationSignal = request.cancellationSignal, ) TransportPolicy.NATIVE_CURL_ONLY -> { if (!NativeCurlHttpClient.canExecute(request)) { diff --git a/app/src/test/java/com/notcvnt/rknhardering/MainActivityUiRenderingTest.kt b/app/src/test/java/com/notcvnt/rknhardering/MainActivityUiRenderingTest.kt new file mode 100644 index 0000000..164e2f8 --- /dev/null +++ b/app/src/test/java/com/notcvnt/rknhardering/MainActivityUiRenderingTest.kt @@ -0,0 +1,136 @@ +package com.notcvnt.rknhardering + +import android.view.View +import android.view.ViewGroup +import android.widget.ImageView +import android.widget.LinearLayout +import android.widget.TextView +import androidx.test.core.app.ApplicationProvider +import com.google.android.material.card.MaterialCardView +import com.notcvnt.rknhardering.model.CategoryResult +import com.notcvnt.rknhardering.model.CdnPullingResponse +import com.notcvnt.rknhardering.model.Finding +import com.notcvnt.rknhardering.model.IpCheckerGroupResult +import com.notcvnt.rknhardering.model.IpCheckerResponse +import com.notcvnt.rknhardering.model.IpCheckerScope +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.Robolectric +import org.robolectric.RobolectricTestRunner + +@RunWith(RobolectricTestRunner::class) +class MainActivityUiRenderingTest { + + @Test + fun `ip comparison response view hides raw error details`() { + val activity = Robolectric.buildActivity(MainActivity::class.java).setup().get() + val response = IpCheckerResponse( + label = "ip.sb IPv6", + url = "https://api-ipv6.ip.sb/ip", + scope = IpCheckerScope.NON_RU, + error = "OkHttp failed after 3 attempts: Could not connect to server", + ) + + val view = invokePrivate(activity, "createIpCheckerResponseView", response, false) + val text = collectText(view) + + assertTrueContains(text, "ip.sb IPv6") + assertTrueContains(text, "https://api-ipv6.ip.sb/ip") + assertTrueContains(text, activity.getString(R.string.main_card_status_error)) + assertFalse(text.contains("OkHttp failed after 3 attempts")) + assertFalse(text.contains("Could not connect to server")) + } + + @Test + fun `ip comparison response view hides ignored ipv6 diagnostics`() { + val activity = Robolectric.buildActivity(MainActivity::class.java).setup().get() + val response = IpCheckerResponse( + label = "ip.sb IPv6", + url = "https://api-ipv6.ip.sb/ip", + scope = IpCheckerScope.NON_RU, + error = "native curl failed", + ignoredIpv6Error = true, + ) + + val view = invokePrivate(activity, "createIpCheckerResponseView", response, false) + val text = collectText(view) + + assertFalse(text.contains(activity.getString(R.string.main_ipv6_error_ignored).trim())) + assertFalse(text.contains("native curl failed")) + } + + @Test + fun `cdn pulling response view hides raw error details`() { + val activity = Robolectric.buildActivity(MainActivity::class.java).setup().get() + val response = CdnPullingResponse( + targetLabel = "meduza.io", + url = "https://meduza.io", + error = "SSLHandshakeException: certificate path validation failed", + ) + + val view = invokePrivate(activity, "createCdnPullingResponseView", response, false) + val text = collectText(view) + + assertTrueContains(text, "meduza.io") + assertTrueContains(text, "https://meduza.io") + assertTrueContains(text, activity.getString(R.string.main_card_status_error)) + assertFalse(text.contains("SSLHandshakeException")) + assertFalse(text.contains("certificate path validation failed")) + } + + @Test + fun `display category keeps error status and hides error finding text`() { + val activity = Robolectric.buildActivity(MainActivity::class.java).setup().get() + val category = CategoryResult( + name = "direct", + detected = false, + findings = listOf(Finding("Socket timeout to 203.0.113.64", isError = true)), + ) + val card = activity.findViewById(R.id.cardIndirect) + val icon = activity.findViewById(R.id.iconIndirect) + val status = activity.findViewById(R.id.statusIndirect) + val findings = activity.findViewById(R.id.findingsIndirect) + + invokePrivate( + activity, + "displayCategory", + category, + card, + icon, + status, + findings, + false, + ) + + assertEquals(activity.getString(R.string.main_card_status_error), status.text.toString()) + assertFalse(collectText(findings).contains("Socket timeout to 203.0.113.64")) + } + + private fun collectText(view: View): String { + if (view is TextView) return view.text.toString() + if (view !is ViewGroup) return "" + return buildString { + for (index in 0 until view.childCount) { + val childText = collectText(view.getChildAt(index)) + if (childText.isBlank()) continue + if (isNotBlank()) append('\n') + append(childText) + } + } + } + + private fun assertTrueContains(text: String, expected: String) { + assertFalse("Expected text to contain <$expected>, got <$text>", !text.contains(expected)) + } + + @Suppress("UNCHECKED_CAST") + private fun invokePrivate(target: Any, name: String, vararg args: Any?): T { + val method = target::class.java.declaredMethods.first { candidate -> + candidate.name == name && candidate.parameterTypes.size == args.size + } + method.isAccessible = true + return method.invoke(target, *args) as T + } +} diff --git a/app/src/test/java/com/notcvnt/rknhardering/checker/CdnPullingCheckerTest.kt b/app/src/test/java/com/notcvnt/rknhardering/checker/CdnPullingCheckerTest.kt index c357792..7fba323 100644 --- a/app/src/test/java/com/notcvnt/rknhardering/checker/CdnPullingCheckerTest.kt +++ b/app/src/test/java/com/notcvnt/rknhardering/checker/CdnPullingCheckerTest.kt @@ -103,7 +103,7 @@ class CdnPullingCheckerTest { resolverConfig = DnsResolverConfig.system(), maxAttempts = 3, retryDelayMs = 0, - ) { _, _, _ -> + ) { _, _, _, _ -> attempts += 1 if (attempts < 3) { Result.failure(IOException("timeout")) @@ -130,7 +130,7 @@ class CdnPullingCheckerTest { resolverConfig = DnsResolverConfig.system(), maxAttempts = 3, retryDelayMs = 0, - ) { _, _, _ -> + ) { _, _, _, _ -> attempts += 1 Result.failure(error) } diff --git a/app/src/test/java/com/notcvnt/rknhardering/checker/IpComparisonCheckerTest.kt b/app/src/test/java/com/notcvnt/rknhardering/checker/IpComparisonCheckerTest.kt index 13ab567..91057a7 100644 --- a/app/src/test/java/com/notcvnt/rknhardering/checker/IpComparisonCheckerTest.kt +++ b/app/src/test/java/com/notcvnt/rknhardering/checker/IpComparisonCheckerTest.kt @@ -199,7 +199,7 @@ class IpComparisonCheckerTest { resolverConfig = DnsResolverConfig.system(), maxAttempts = 3, retryDelayMs = 0, - ) { _, _, _ -> + ) { _, _, _, _ -> attempts += 1 if (attempts < 3) { Result.failure(IOException("timeout")) diff --git a/app/src/test/java/com/notcvnt/rknhardering/probe/XrayApiClientTest.kt b/app/src/test/java/com/notcvnt/rknhardering/probe/XrayApiClientTest.kt new file mode 100644 index 0000000..63a1bb5 --- /dev/null +++ b/app/src/test/java/com/notcvnt/rknhardering/probe/XrayApiClientTest.kt @@ -0,0 +1,58 @@ +package com.notcvnt.rknhardering.probe + +import com.notcvnt.rknhardering.ScanCancellationSignal +import com.notcvnt.rknhardering.ScanExecutionContext +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.runBlocking +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Test +import java.net.InetAddress +import java.net.ServerSocket +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import kotlin.concurrent.thread + +class XrayApiClientTest { + + @Test + fun `cancelled grpc call shuts down channel immediately`() { + ServerSocket(0, 50, InetAddress.getByName("127.0.0.1")).use { server -> + val accepted = CountDownLatch(1) + val releaseServer = CountDownLatch(1) + val executionContext = ScanExecutionContext(cancellationSignal = ScanCancellationSignal()) + + val serverWorker = thread(start = true, isDaemon = true) { + try { + server.accept().use { + accepted.countDown() + releaseServer.await(2, TimeUnit.SECONDS) + } + } catch (_: Exception) { + } + } + + var failure: Throwable? = null + val clientWorker = thread(start = true) { + failure = runCatching { + runBlocking { + XrayApiClient("127.0.0.1").listOutbounds( + port = server.localPort, + deadlineMs = 30_000, + executionContext = executionContext, + ) + } + }.exceptionOrNull() + } + + assertTrue(accepted.await(2, TimeUnit.SECONDS)) + executionContext.cancellationSignal.cancel() + clientWorker.join(2_000) + releaseServer.countDown() + serverWorker.join(2_000) + + assertFalse(clientWorker.isAlive) + assertTrue(failure is CancellationException) + } + } +}