diff --git a/android-test/build.gradle.kts b/android-test/build.gradle.kts index 60e705263d64..7258aef6819f 100644 --- a/android-test/build.gradle.kts +++ b/android-test/build.gradle.kts @@ -66,6 +66,7 @@ dependencies { "friendsImplementation"(projects.okhttpDnsoverhttps) testImplementation(projects.okhttp) + testImplementation(projects.okhttpCoroutines) testImplementation(libs.junit) testImplementation(libs.junit.ktx) testImplementation(libs.assertk) diff --git a/android-test/src/androidTest/java/okhttp/android/test/AlwaysHttps.kt b/android-test/src/androidTest/java/okhttp/android/test/AlwaysHttps.kt new file mode 100644 index 000000000000..063f0c1c597f --- /dev/null +++ b/android-test/src/androidTest/java/okhttp/android/test/AlwaysHttps.kt @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2025 Block, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp.android.test + +import android.os.Build +import android.security.NetworkSecurityPolicy +import okhttp3.Call +import okhttp3.Request + +class AlwaysHttps( + policy: Policy, +) : Call.Decorator { + val hostPolicy: HostPolicy = policy.hostPolicy + + override fun newCall(chain: Call.Chain): Call { + val request = chain.request + + val updatedRequest = + if (request.url.scheme == "http" && !hostPolicy.isCleartextTrafficPermitted(request)) { + request + .newBuilder() + .url( + request.url + .newBuilder() + .scheme("https") + .build(), + ).build() + } else { + request + } + + return chain.proceed(updatedRequest) + } + + fun interface HostPolicy { + fun isCleartextTrafficPermitted(request: Request): Boolean + } + + enum class Policy { + Always { + override val hostPolicy: HostPolicy + get() = HostPolicy { false } + }, + Manifest { + override val hostPolicy: HostPolicy + get() = + if (Build.VERSION.SDK_INT > Build.VERSION_CODES.M) { + val networkSecurityPolicy = NetworkSecurityPolicy.getInstance() + + if (Build.VERSION.SDK_INT > Build.VERSION_CODES.N) { + HostPolicy { networkSecurityPolicy.isCleartextTrafficPermitted(it.url.host) } + } else { + HostPolicy { networkSecurityPolicy.isCleartextTrafficPermitted } + } + } else { + HostPolicy { true } + } + }, ; + + abstract val hostPolicy: HostPolicy + } +} diff --git a/android-test/src/androidTest/java/okhttp/android/test/AndroidCallDecoratorTest.kt b/android-test/src/androidTest/java/okhttp/android/test/AndroidCallDecoratorTest.kt new file mode 100644 index 000000000000..9f062fc01354 --- /dev/null +++ b/android-test/src/androidTest/java/okhttp/android/test/AndroidCallDecoratorTest.kt @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2025 Block, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp.android.test + +import java.util.logging.Logger +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import mockwebserver3.junit5.StartStop +import okhttp.android.test.AlwaysHttps.Policy +import okhttp3.OkHttpClient +import okhttp3.OkHttpClientTestRule +import okhttp3.Request +import okhttp3.tls.internal.TlsUtil.localhost +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +@Tag("Slow") +class AndroidCallDecoratorTest { + @Suppress("RedundantVisibilityModifier") + @JvmField + @RegisterExtension + public val clientTestRule = + OkHttpClientTestRule().apply { + logger = Logger.getLogger(AndroidCallDecoratorTest::class.java.name) + } + + private var client: OkHttpClient = + clientTestRule + .newClientBuilder() + .addCallDecorator(AlwaysHttps(Policy.Always)) + .addCallDecorator(OffMainThread) + .build() + + @StartStop + private val server = MockWebServer() + + private val handshakeCertificates = localhost() + + @Test + fun testSecureRequest() { + enableTls() + + server.enqueue(MockResponse()) + + val request = Request.Builder().url(server.url("/")).build() + + client.newCall(request).execute().use { + assertEquals(200, it.code) + } + } + + @Test + fun testInsecureRequestChangedToSecure() { + enableTls() + + server.enqueue(MockResponse()) + + val request = + Request + .Builder() + .url( + server + .url("/") + .newBuilder() + .scheme("http") + .build(), + ).build() + + client.newCall(request).execute().use { + assertEquals(200, it.code) + assertEquals("https", it.request.url.scheme) + } + } + + private fun enableTls() { + client = + client + .newBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), + handshakeCertificates.trustManager, + ).build() + server.useHttps(handshakeCertificates.sslSocketFactory()) + } +} diff --git a/android-test/src/androidTest/java/okhttp/android/test/OffMainThread.kt b/android-test/src/androidTest/java/okhttp/android/test/OffMainThread.kt new file mode 100644 index 000000000000..4fff9c5bb48b --- /dev/null +++ b/android-test/src/androidTest/java/okhttp/android/test/OffMainThread.kt @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2025 Block, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp.android.test + +import android.os.Looper +import okhttp3.Call +import okhttp3.Response + +/** + * Sample of a Decorator that will fail any call on the Android Main thread. + */ +object OffMainThread : Call.Decorator { + override fun newCall(chain: Call.Chain): Call = StrictModeCall(chain.proceed(chain.request)) + + private class StrictModeCall( + private val delegate: Call, + ) : Call by delegate { + override fun execute(): Response { + if (Looper.getMainLooper() === Looper.myLooper()) { + throw IllegalStateException("Network on main thread") + } + + return delegate.execute() + } + + override fun clone(): Call = StrictModeCall(delegate.clone()) + } +} diff --git a/okhttp/api/android/okhttp.api b/okhttp/api/android/okhttp.api index 0cc0e076195d..78f4016d1a7a 100644 --- a/okhttp/api/android/okhttp.api +++ b/okhttp/api/android/okhttp.api @@ -129,6 +129,16 @@ public abstract interface class okhttp3/Call : java/lang/Cloneable { public abstract fun timeout ()Lokio/Timeout; } +public abstract interface class okhttp3/Call$Chain { + public abstract fun getClient ()Lokhttp3/OkHttpClient; + public abstract fun getRequest ()Lokhttp3/Request; + public abstract fun proceed (Lokhttp3/Request;)Lokhttp3/Call; +} + +public abstract interface class okhttp3/Call$Decorator { + public abstract fun newCall (Lokhttp3/Call$Chain;)Lokhttp3/Call; +} + public abstract interface class okhttp3/Call$Factory { public abstract fun newCall (Lokhttp3/Request;)Lokhttp3/Call; } @@ -905,6 +915,7 @@ public class okhttp3/OkHttpClient : okhttp3/Call$Factory, okhttp3/WebSocket$Fact public final fun fastFallback ()Z public final fun followRedirects ()Z public final fun followSslRedirects ()Z + public final fun getCallDecorators ()Ljava/util/List; public final fun hostnameVerifier ()Ljavax/net/ssl/HostnameVerifier; public final fun interceptors ()Ljava/util/List; public final fun minWebSocketMessageToCompress ()J @@ -930,6 +941,7 @@ public final class okhttp3/OkHttpClient$Builder { public final fun -addInterceptor (Lkotlin/jvm/functions/Function1;)Lokhttp3/OkHttpClient$Builder; public final fun -addNetworkInterceptor (Lkotlin/jvm/functions/Function1;)Lokhttp3/OkHttpClient$Builder; public fun ()V + public final fun addCallDecorator (Lokhttp3/Call$Decorator;)Lokhttp3/OkHttpClient$Builder; public final fun addInterceptor (Lokhttp3/Interceptor;)Lokhttp3/OkHttpClient$Builder; public final fun addNetworkInterceptor (Lokhttp3/Interceptor;)Lokhttp3/OkHttpClient$Builder; public final fun authenticator (Lokhttp3/Authenticator;)Lokhttp3/OkHttpClient$Builder; diff --git a/okhttp/api/jvm/okhttp.api b/okhttp/api/jvm/okhttp.api index d275ffa0f73b..498335764aad 100644 --- a/okhttp/api/jvm/okhttp.api +++ b/okhttp/api/jvm/okhttp.api @@ -129,6 +129,16 @@ public abstract interface class okhttp3/Call : java/lang/Cloneable { public abstract fun timeout ()Lokio/Timeout; } +public abstract interface class okhttp3/Call$Chain { + public abstract fun getClient ()Lokhttp3/OkHttpClient; + public abstract fun getRequest ()Lokhttp3/Request; + public abstract fun proceed (Lokhttp3/Request;)Lokhttp3/Call; +} + +public abstract interface class okhttp3/Call$Decorator { + public abstract fun newCall (Lokhttp3/Call$Chain;)Lokhttp3/Call; +} + public abstract interface class okhttp3/Call$Factory { public abstract fun newCall (Lokhttp3/Request;)Lokhttp3/Call; } @@ -904,6 +914,7 @@ public class okhttp3/OkHttpClient : okhttp3/Call$Factory, okhttp3/WebSocket$Fact public final fun fastFallback ()Z public final fun followRedirects ()Z public final fun followSslRedirects ()Z + public final fun getCallDecorators ()Ljava/util/List; public final fun hostnameVerifier ()Ljavax/net/ssl/HostnameVerifier; public final fun interceptors ()Ljava/util/List; public final fun minWebSocketMessageToCompress ()J @@ -929,6 +940,7 @@ public final class okhttp3/OkHttpClient$Builder { public final fun -addInterceptor (Lkotlin/jvm/functions/Function1;)Lokhttp3/OkHttpClient$Builder; public final fun -addNetworkInterceptor (Lkotlin/jvm/functions/Function1;)Lokhttp3/OkHttpClient$Builder; public fun ()V + public final fun addCallDecorator (Lokhttp3/Call$Decorator;)Lokhttp3/OkHttpClient$Builder; public final fun addInterceptor (Lokhttp3/Interceptor;)Lokhttp3/OkHttpClient$Builder; public final fun addNetworkInterceptor (Lokhttp3/Interceptor;)Lokhttp3/OkHttpClient$Builder; public final fun authenticator (Lokhttp3/Authenticator;)Lokhttp3/OkHttpClient$Builder; diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/Call.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/Call.kt index fdd3d3da294e..371bd4c715e2 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/Call.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/Call.kt @@ -96,4 +96,33 @@ interface Call : Cloneable { fun interface Factory { fun newCall(request: Request): Call } + + /** + * The equivalent of an Interceptor for [Call.Factory], but supported directly within [OkHttpClient] newCall. + * + * An [Interceptor] forms a chain as part of execution of a Call. Instead, Call.Decorator intercepts + * [Call.Factory.newCall] with similar flexibility to Application [OkHttpClient.interceptors]. + * + * That is, it may do any of + * - Modify the request such as adding Tracing Context + * - Wrap the [Call] returned + * - Return some [Call] implementation that will immediately fail avoiding network calls based on network or + * authentication state. + * - Redirect the [Call], such as using an alternative [Call.Factory]. + * - Defer execution, something not safe in an Interceptor. + * + * It should not throw an exception, instead it should return a Call that will fail on [Call.execute]. + * + * A Decorator that changes the OkHttpClient should typically retain later decorators in the new client. + */ + fun interface Decorator { + fun newCall(chain: Chain): Call + } + + interface Chain { + val client: OkHttpClient + val request: Request + + fun proceed(request: Request): Call + } } diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/OkHttpClient.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/OkHttpClient.kt index 6ab5bcf7d374..8af05700e36d 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/OkHttpClient.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/OkHttpClient.kt @@ -145,6 +145,14 @@ open class OkHttpClient internal constructor( val interceptors: List = builder.interceptors.toImmutableList() + /** + * Returns an immutable list of Call decorators that have a chance to return a different, likely + * decorating, implementation of Call. This allows functionality such as fail fast without normal Call + * execution based on network conditions, or setting Tracing context on the calling thread. + */ + val callDecorators: List = + builder.callDecorators.toImmutableList() + /** * Returns an immutable list of interceptors that observe a single network request and response. * These interceptors must call [Interceptor.Chain.proceed] exactly once: it is an error for @@ -265,6 +273,27 @@ open class OkHttpClient internal constructor( internal val routeDatabase: RouteDatabase = builder.routeDatabase ?: RouteDatabase() internal val taskRunner: TaskRunner = builder.taskRunner ?: TaskRunner.INSTANCE + private val decoratedCallFactory = + callDecorators.foldRight( + Call.Factory { request -> + RealCall(client = this, originalRequest = request, forWebSocket = false) + }, + ) { callDecorator, next -> + Call.Factory { request -> + callDecorator.newCall( + object : Call.Chain { + override val client: OkHttpClient + get() = this@OkHttpClient + + override val request: Request + get() = request + + override fun proceed(request: Request): Call = next.newCall(request) + }, + ) + } + } + @get:JvmName("connectionPool") val connectionPool: ConnectionPool = builder.connectionPool ?: ConnectionPool().also { @@ -350,7 +379,7 @@ open class OkHttpClient internal constructor( } /** Prepares the [request] to be executed at some point in the future. */ - override fun newCall(request: Request): Call = RealCall(this, request, forWebSocket = false) + override fun newCall(request: Request): Call = decoratedCallFactory.newCall(request) /** Uses [request] to connect a new web socket. */ override fun newWebSocket( @@ -587,6 +616,7 @@ open class OkHttpClient internal constructor( internal var dispatcher: Dispatcher = Dispatcher() internal var connectionPool: ConnectionPool? = null internal val interceptors: MutableList = mutableListOf() + internal val callDecorators: MutableList = mutableListOf() internal val networkInterceptors: MutableList = mutableListOf() internal var eventListenerFactory: EventListener.Factory = EventListener.NONE.asFactory() internal var retryOnConnectionFailure = true @@ -622,6 +652,7 @@ open class OkHttpClient internal constructor( this.dispatcher = okHttpClient.dispatcher this.connectionPool = okHttpClient.connectionPool this.interceptors += okHttpClient.interceptors + this.callDecorators += okHttpClient.callDecorators this.networkInterceptors += okHttpClient.networkInterceptors this.eventListenerFactory = okHttpClient.eventListenerFactory this.retryOnConnectionFailure = okHttpClient.retryOnConnectionFailure @@ -726,6 +757,11 @@ open class OkHttpClient internal constructor( this.eventListenerFactory = eventListenerFactory } + fun addCallDecorator(decorator: Call.Decorator) = + apply { + callDecorators += decorator + } + /** * Configure this client to retry or not when a connectivity problem is encountered. By default, * this client silently recovers from the following problems: diff --git a/okhttp/src/jvmTest/kotlin/okhttp3/CallDecoratorTest.kt b/okhttp/src/jvmTest/kotlin/okhttp3/CallDecoratorTest.kt new file mode 100644 index 000000000000..9f4c929f8103 --- /dev/null +++ b/okhttp/src/jvmTest/kotlin/okhttp3/CallDecoratorTest.kt @@ -0,0 +1,204 @@ +/* + * Copyright (C) 2025 Block, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.io.IOException +import java.util.logging.Logger +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import mockwebserver3.junit5.StartStop +import okhttp3.internal.connection.RealCall +import okhttp3.tls.internal.TlsUtil.localhost +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +@Tag("Slow") +class CallDecoratorTest { + @Suppress("RedundantVisibilityModifier") + @JvmField + @RegisterExtension + public val clientTestRule = + OkHttpClientTestRule().apply { + logger = Logger.getLogger(CallDecoratorTest::class.java.name) + } + + @StartStop + private val server = MockWebServer() + + private val handshakeCertificates = localhost() + + @Test + fun testSecureRequest() { + server.enqueue(MockResponse()) + + val request = Request.Builder().url(server.url("/")).build() + + val client: OkHttpClient = + clientTestRule + .newClientBuilder() + .enableTls() + .addCallDecorator(AlwaysHttps) + .build() + + client.newCall(request).execute().use { + assertEquals(200, it.code) + } + } + + @Test + fun testInsecureRequestChangedToSecure() { + server.enqueue(MockResponse()) + + val request = + Request + .Builder() + .url( + server + .url("/") + .newBuilder() + .scheme("http") + .build(), + ).build() + + val client: OkHttpClient = + clientTestRule + .newClientBuilder() + .enableTls() + .addCallDecorator(AlwaysHttps) + .build() + + client.newCall(request).execute().use { + assertEquals(200, it.code) + assertEquals("https", it.request.url.scheme) + } + } + + class WrappedCall( + delegate: Call, + ) : Call by delegate + + @Test + fun testWrappedCallIsObserved() { + server.enqueue(MockResponse()) + + val client: OkHttpClient = + clientTestRule + .newClientBuilder() + .addCallDecorator { chain, request -> + // First Call.Decorator will see the result of later decorators + chain.newCall(request).also { + if (it !is WrappedCall) { + throw IOException("expecting wrapped call") + } + if (it.request().tag() != "wrapped") { + throw IOException("expecting tag1") + } + } + }.addCallDecorator { chain, request -> + // Wrap here + val updatedRequest = request.newBuilder().tag("wrapped").build() + WrappedCall(chain.newCall(updatedRequest)) + }.addCallDecorator { chain, request -> + // Updated requests are seen + if (request.tag() != "wrapped") { + throw IOException("expecting tag2") + } + chain.newCall(request).also { + // But Call is RealCall + if (it !is RealCall) { + throw IOException("expecting RealCall") + } + } + }.addInterceptor { chain -> + // Updated requests are seen in interceptors + if (chain.request().tag() != "wrapped") { + throw IOException("expecting tag3") + } + chain.proceed(chain.request()) + }.addNetworkInterceptor { chain -> + // and network interceptors + if (chain.request().tag() != "wrapped") { + throw IOException("expecting tag4") + } + chain.proceed(chain.request()) + }.build() + + val originalRequest = Request.Builder().url(server.url("/")).build() + client.newCall(originalRequest).execute().use { + assertEquals(200, it.code) + } + } + + @Test + fun testCanShortCircuit() { + server.enqueue(MockResponse()) + + val request = Request.Builder().url(server.url("/")).build() + + val client: OkHttpClient = + clientTestRule + .newClientBuilder() + .build() + + val redirectingClient: OkHttpClient = + clientTestRule + .newClientBuilder() + .addCallDecorator { _, request -> + // Use the other client + client.newCall(request) + }.addInterceptor { + // Fail if we get here + throw IOException("You shall not pass") + }.build() + + redirectingClient.newCall(request).execute().use { + assertEquals(200, it.code) + } + } + + private fun OkHttpClient.Builder.enableTls(): OkHttpClient.Builder { + server.useHttps(handshakeCertificates.sslSocketFactory()) + return sslSocketFactory( + handshakeCertificates.sslSocketFactory(), + handshakeCertificates.trustManager, + ) + } +} + +private object AlwaysHttps : Call.Decorator { + override fun newCall( + chain: Call.Factory, + request: Request, + ): Call { + val updatedRequest = + if (request.url.scheme == "http") { + request + .newBuilder() + .url( + request.url + .newBuilder() + .scheme("https") + .build(), + ).build() + } else { + request + } + + return chain.newCall(updatedRequest) + } +}