diff --git a/api/src/main/java/io/grpc/ChannelConfigurer.java b/api/src/main/java/io/grpc/ChannelConfigurer.java new file mode 100644 index 00000000000..78288a87ffe --- /dev/null +++ b/api/src/main/java/io/grpc/ChannelConfigurer.java @@ -0,0 +1,73 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + + + +/** + * A configurer for child channels created by gRPC's internal infrastructure. + * + *

This interface allows users to inject configuration (such as credentials, interceptors, + * or flow control settings) into channels created automatically by gRPC for control plane + * operations. Common use cases include: + *

+ * + *

Usage Example: + *

{@code
+ * // 1. Define the configurer
+ * ChannelConfigurer configurer = builder -> {
+ *   builder.maxInboundMessageSize(4 * 1024 * 1024);
+ * };
+ *
+ * // 2. Apply to parent channel - automatically used for ALL child channels
+ * ManagedChannel channel = ManagedChannelBuilder
+ *     .forTarget("xds:///my-service")
+ *     .childChannelConfigurer(configurer)
+ *     .build();
+ * }
+ * + *

Implementations must be thread-safe as the configure methods may be invoked concurrently + * by multiple internal components. + * + * @since 1.81.0 + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/12574") +public interface ChannelConfigurer { + + /** + * Configures a builder for a new child channel. + * + *

This method is invoked synchronously during the creation of the child channel, + * before {@link ManagedChannelBuilder#build()} is called. + * + * @param builder the mutable channel builder for the new child channel + */ + default void configureChannelBuilder(ManagedChannelBuilder builder) {} + + /** + * Configures a builder for a new child server. + * + *

This method is invoked synchronously during the creation of the child server, + * before {@link ServerBuilder#build()} is called. + * + * @param builder the mutable server builder for the new child server + */ + default void configureServerBuilder(ServerBuilder builder) {} +} diff --git a/api/src/main/java/io/grpc/ForwardingChannelBuilder.java b/api/src/main/java/io/grpc/ForwardingChannelBuilder.java index 1202582421a..33276d2fdd3 100644 --- a/api/src/main/java/io/grpc/ForwardingChannelBuilder.java +++ b/api/src/main/java/io/grpc/ForwardingChannelBuilder.java @@ -242,6 +242,13 @@ public T disableServiceConfigLookUp() { return thisT(); } + + @Override + public T childChannelConfigurer(ChannelConfigurer channelConfigurer) { + delegate().childChannelConfigurer(channelConfigurer); + return thisT(); + } + /** * Returns the correctly typed version of the builder. */ diff --git a/api/src/main/java/io/grpc/ForwardingChannelBuilder2.java b/api/src/main/java/io/grpc/ForwardingChannelBuilder2.java index 78fe730d91a..0fd3f1fb209 100644 --- a/api/src/main/java/io/grpc/ForwardingChannelBuilder2.java +++ b/api/src/main/java/io/grpc/ForwardingChannelBuilder2.java @@ -258,7 +258,7 @@ public T disableServiceConfigLookUp() { } @Override - protected T addMetricSink(MetricSink metricSink) { + public T addMetricSink(MetricSink metricSink) { delegate().addMetricSink(metricSink); return thisT(); } @@ -269,6 +269,13 @@ public T setNameResolverArg(NameResolver.Args.Key key, X value) { return thisT(); } + + @Override + public T childChannelConfigurer(ChannelConfigurer channelConfigurer) { + delegate().childChannelConfigurer(channelConfigurer); + return thisT(); + } + /** * Returns the {@link ManagedChannel} built by the delegate by default. Overriding method can * return different value. diff --git a/api/src/main/java/io/grpc/ForwardingServerBuilder.java b/api/src/main/java/io/grpc/ForwardingServerBuilder.java index 9cef7cfa331..a0478ddad49 100644 --- a/api/src/main/java/io/grpc/ForwardingServerBuilder.java +++ b/api/src/main/java/io/grpc/ForwardingServerBuilder.java @@ -192,6 +192,12 @@ public T setBinaryLog(BinaryLog binaryLog) { return thisT(); } + @Override + public T childChannelConfigurer(ChannelConfigurer channelConfigurer) { + delegate().childChannelConfigurer(channelConfigurer); + return thisT(); + } + /** * Returns the {@link Server} built by the delegate by default. Overriding method can return * different value. diff --git a/api/src/main/java/io/grpc/ManagedChannelBuilder.java b/api/src/main/java/io/grpc/ManagedChannelBuilder.java index 3f370ab3003..bf00854f30a 100644 --- a/api/src/main/java/io/grpc/ManagedChannelBuilder.java +++ b/api/src/main/java/io/grpc/ManagedChannelBuilder.java @@ -168,9 +168,7 @@ protected T interceptWithTarget(InterceptorFactory factory) { throw new UnsupportedOperationException(); } - /** Internal-only. */ - @Internal - protected interface InterceptorFactory { + public interface InterceptorFactory { ClientInterceptor newInterceptor(String target); } @@ -638,8 +636,7 @@ public T disableServiceConfigLookUp() { * @return this * @since 1.64.0 */ - @Internal - protected T addMetricSink(MetricSink metricSink) { + public T addMetricSink(MetricSink metricSink) { throw new UnsupportedOperationException(); } @@ -661,6 +658,22 @@ public T setNameResolverArg(NameResolver.Args.Key key, X value) { throw new UnsupportedOperationException(); } + + /** + * Sets a configurer that will be applied to all internal child channels created by this channel. + * + *

This allows injecting configuration (like credentials, interceptors, or flow control) + * into auxiliary channels created by gRPC infrastructure, such as xDS control plane connections. + * + * @param channelConfigurer the configurer to apply. + * @return this + * @since 1.81.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12574") + public T childChannelConfigurer(ChannelConfigurer channelConfigurer) { + throw new UnsupportedOperationException("Not implemented"); + } + /** * Builds a channel using the given parameters. * diff --git a/api/src/main/java/io/grpc/MetricRecorder.java b/api/src/main/java/io/grpc/MetricRecorder.java index 897c28011cd..1f765ddc115 100644 --- a/api/src/main/java/io/grpc/MetricRecorder.java +++ b/api/src/main/java/io/grpc/MetricRecorder.java @@ -26,6 +26,15 @@ */ @Internal public interface MetricRecorder { + + /** + * Returns a {@link MetricRecorder} that performs no operations. + * The returned instance ignores all calls and skips all validation checks. + */ + static MetricRecorder noOp() { + return NoOpMetricRecorder.INSTANCE; + } + /** * Adds a value for a double-precision counter metric instrument. * @@ -176,4 +185,47 @@ interface Registration extends AutoCloseable { @Override void close(); } + + /** + * No-Op implementation of MetricRecorder. + * Overrides all default methods to skip validation checks for maximum performance. + */ + final class NoOpMetricRecorder implements MetricRecorder { + private static final NoOpMetricRecorder INSTANCE = new NoOpMetricRecorder(); + + @Override + public void addDoubleCounter(DoubleCounterMetricInstrument metricInstrument, double value, + List requiredLabelValues, + List optionalLabelValues) { + } + + @Override + public void addLongCounter(LongCounterMetricInstrument metricInstrument, long value, + List requiredLabelValues, List optionalLabelValues) { + } + + @Override + public void addLongUpDownCounter(LongUpDownCounterMetricInstrument metricInstrument, long value, + List requiredLabelValues, + List optionalLabelValues) { + } + + @Override + public void recordDoubleHistogram(DoubleHistogramMetricInstrument metricInstrument, + double value, List requiredLabelValues, + List optionalLabelValues) { + } + + @Override + public void recordLongHistogram(LongHistogramMetricInstrument metricInstrument, long value, + List requiredLabelValues, + List optionalLabelValues) { + } + + @Override + public Registration registerBatchCallback(BatchCallback callback, + CallbackMetricInstrument... metricInstruments) { + return () -> { }; + } + } } diff --git a/api/src/main/java/io/grpc/NameResolver.java b/api/src/main/java/io/grpc/NameResolver.java index 53dbc5d6888..0d70c20dc0b 100644 --- a/api/src/main/java/io/grpc/NameResolver.java +++ b/api/src/main/java/io/grpc/NameResolver.java @@ -358,6 +358,7 @@ public static final class Args { @Nullable private final MetricRecorder metricRecorder; @Nullable private final NameResolverRegistry nameResolverRegistry; @Nullable private final IdentityHashMap, Object> customArgs; + @Nullable private final ChannelConfigurer channelConfigurer; private Args(Builder builder) { this.defaultPort = checkNotNull(builder.defaultPort, "defaultPort not set"); @@ -372,6 +373,7 @@ private Args(Builder builder) { this.metricRecorder = builder.metricRecorder; this.nameResolverRegistry = builder.nameResolverRegistry; this.customArgs = cloneCustomArgs(builder.customArgs); + this.channelConfigurer = builder.channelConfigurer; } /** @@ -470,6 +472,17 @@ public ChannelLogger getChannelLogger() { return channelLogger; } + /** + * Returns the configurer for child channels. + * + * @since 1.81.0 + */ + @Nullable + @Internal + public ChannelConfigurer getChildChannelConfigurer() { + return channelConfigurer; + } + /** * Returns the Executor on which this resolver should execute long-running or I/O bound work. * Null if no Executor was set. @@ -579,6 +592,7 @@ public static final class Builder { private MetricRecorder metricRecorder; private NameResolverRegistry nameResolverRegistry; private IdentityHashMap, Object> customArgs; + private ChannelConfigurer channelConfigurer = new ChannelConfigurer() {}; Builder() { } @@ -694,6 +708,16 @@ public Builder setNameResolverRegistry(NameResolverRegistry registry) { return this; } + /** + * See {@link Args#getChildChannelConfigurer()}. This is an optional field. + * + * @since 1.81.0 + */ + public Builder setChildChannelConfigurer(ChannelConfigurer channelConfigurer) { + this.channelConfigurer = channelConfigurer; + return this; + } + /** * Builds an {@link Args}. * diff --git a/api/src/main/java/io/grpc/Server.java b/api/src/main/java/io/grpc/Server.java index 97ea06a81c2..d744752ecc5 100644 --- a/api/src/main/java/io/grpc/Server.java +++ b/api/src/main/java/io/grpc/Server.java @@ -178,4 +178,5 @@ public List getMutableServices() { * @since 1.0.0 */ public abstract void awaitTermination() throws InterruptedException; + } diff --git a/api/src/main/java/io/grpc/ServerBuilder.java b/api/src/main/java/io/grpc/ServerBuilder.java index cd1cddbb93f..54f69d32e0b 100644 --- a/api/src/main/java/io/grpc/ServerBuilder.java +++ b/api/src/main/java/io/grpc/ServerBuilder.java @@ -424,6 +424,24 @@ public T setBinaryLog(BinaryLog binaryLog) { throw new UnsupportedOperationException(); } + + /** + * Sets a configurer that will be applied to all internal child channels created by this server. + * + *

This allows injecting configuration (like credentials, interceptors, or flow control) + * into auxiliary channels created by gRPC infrastructure, such as xDS control plane connections + * or OOB load balancing channels. + * + * @param channelConfigurer the configurer to apply. + * @return this + * @since 1.81.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12574") + public T childChannelConfigurer(ChannelConfigurer channelConfigurer) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** * Builds a server using the given parameters. * diff --git a/api/src/test/java/io/grpc/ChannelConfigurerTest.java b/api/src/test/java/io/grpc/ChannelConfigurerTest.java new file mode 100644 index 00000000000..f6e56c81a64 --- /dev/null +++ b/api/src/test/java/io/grpc/ChannelConfigurerTest.java @@ -0,0 +1,41 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoInteractions; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ChannelConfigurerTest { + + @Test + public void defaultMethods_doNothing() { + ChannelConfigurer configurer = new ChannelConfigurer() {}; + + ManagedChannelBuilder mockChannelBuilder = mock(ManagedChannelBuilder.class); + configurer.configureChannelBuilder(mockChannelBuilder); + verifyNoInteractions(mockChannelBuilder); + + ServerBuilder mockServerBuilder = mock(ServerBuilder.class); + configurer.configureServerBuilder(mockServerBuilder); + verifyNoInteractions(mockServerBuilder); + } +} diff --git a/api/src/test/java/io/grpc/NameResolverTest.java b/api/src/test/java/io/grpc/NameResolverTest.java index 82abe5c7505..e3864a7665a 100644 --- a/api/src/test/java/io/grpc/NameResolverTest.java +++ b/api/src/test/java/io/grpc/NameResolverTest.java @@ -105,6 +105,7 @@ public void args() { } private NameResolver.Args createArgs() { + ChannelConfigurer channelConfigurer = mock(ChannelConfigurer.class); return NameResolver.Args.newBuilder() .setDefaultPort(defaultPort) .setProxyDetector(proxyDetector) @@ -116,9 +117,39 @@ private NameResolver.Args createArgs() { .setOverrideAuthority(overrideAuthority) .setMetricRecorder(metricRecorder) .setArg(FOO_ARG_KEY, customArgValue) + .setChildChannelConfigurer(channelConfigurer) .build(); } + @Test + public void args_childChannelConfigurer() { + ChannelConfigurer channelConfigurer = mock(ChannelConfigurer.class); + + SynchronizationContext realSyncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + + NameResolver.Args args = NameResolver.Args.newBuilder() + .setDefaultPort(8080) + .setProxyDetector(mock(ProxyDetector.class)) + .setSynchronizationContext(realSyncContext) + .setServiceConfigParser(mock(NameResolver.ServiceConfigParser.class)) + .setChannelLogger(mock(ChannelLogger.class)) + .setChildChannelConfigurer(channelConfigurer) + .build(); + + assertThat(args.getChildChannelConfigurer()).isSameInstanceAs(channelConfigurer); + + // Validate configurer accepts builders + ManagedChannelBuilder mockBuilder = mock(ManagedChannelBuilder.class); + args.getChildChannelConfigurer().configureChannelBuilder(mockBuilder); + verify(channelConfigurer).configureChannelBuilder(mockBuilder); + } + @Test @SuppressWarnings("deprecation") public void startOnOldListener_wrapperListener2UsedToStart() { diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index e423220e3ad..e5e4332b70f 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -37,6 +37,7 @@ import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.ChannelConfigurer; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; @@ -155,6 +156,14 @@ public Result selectConfig(PickSubchannelArgs args) { private static final LoadBalancer.PickDetailsConsumer NOOP_PICK_DETAILS_CONSUMER = new LoadBalancer.PickDetailsConsumer() {}; + /** + * Retrieves the user-provided configuration function for internal child channels. + * + *

This is intended for use by gRPC internal components + * that are responsible for creating auxiliary {@code ManagedChannel} instances. + */ + private ChannelConfigurer channelConfigurer = new ChannelConfigurer() {}; + private final InternalLogId logId; private final String target; @Nullable @@ -545,6 +554,9 @@ ClientStream newSubstream( Supplier stopwatchSupplier, List interceptors, final TimeProvider timeProvider) { + if (builder.channelConfigurer != null) { + this.channelConfigurer = builder.channelConfigurer; + } this.target = checkNotNull(builder.target, "target"); this.logId = InternalLogId.allocate("Channel", target); this.timeProvider = checkNotNull(timeProvider, "timeProvider"); @@ -589,7 +601,8 @@ ClientStream newSubstream( .setOffloadExecutor(this.offloadExecutorHolder) .setOverrideAuthority(this.authorityOverride) .setMetricRecorder(this.metricRecorder) - .setNameResolverRegistry(builder.nameResolverRegistry); + .setNameResolverRegistry(builder.nameResolverRegistry) + .setChildChannelConfigurer(this.channelConfigurer); builder.copyAllNameResolverCustomArgsTo(nameResolverArgsBuilder); this.nameResolverArgs = nameResolverArgsBuilder.build(); this.nameResolver = getNameResolver( @@ -1486,6 +1499,12 @@ protected ManagedChannelBuilder delegate() { ResolvingOobChannelBuilder builder = new ResolvingOobChannelBuilder(); + // Note that we follow the global configurator pattern and try to fuse the configurations as + // soon as the builder gets created + if (channelConfigurer != null) { + channelConfigurer.configureChannelBuilder(builder); + } + return builder // TODO(zdapeng): executors should not outlive the parent channel. .executor(executor) diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java index 128c929ec0e..2d09d4b495e 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java @@ -29,6 +29,7 @@ import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.ChannelConfigurer; import io.grpc.ChannelCredentials; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; @@ -127,6 +128,16 @@ public static ManagedChannelBuilder forTarget(String target) { private static final Method GET_CLIENT_INTERCEPTOR_METHOD; + ChannelConfigurer channelConfigurer = new ChannelConfigurer() {}; + + @Override + public ManagedChannelImplBuilder childChannelConfigurer( + ChannelConfigurer channelConfigurer) { + this.channelConfigurer = checkNotNull(channelConfigurer, + "childChannelConfigurer"); + return this; + } + static { Method getClientInterceptorMethod = null; try { @@ -403,7 +414,7 @@ public ManagedChannelImplBuilder intercept(ClientInterceptor... interceptors) { } @Override - protected ManagedChannelImplBuilder interceptWithTarget(InterceptorFactory factory) { + public ManagedChannelImplBuilder interceptWithTarget(InterceptorFactory factory) { // Add a placeholder instance to the interceptor list, and replace it with a real instance // during build(). this.interceptors.add(new InterceptorFactoryWrapper(factory)); @@ -712,7 +723,7 @@ public ManagedChannelImplBuilder enableCheckAuthority() { } @Override - protected ManagedChannelImplBuilder addMetricSink(MetricSink metricSink) { + public ManagedChannelImplBuilder addMetricSink(MetricSink metricSink) { metricSinks.add(checkNotNull(metricSink, "metric sink")); return this; } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java index b0939239477..5f935deddbe 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java @@ -26,6 +26,7 @@ import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -34,6 +35,7 @@ import com.google.common.util.concurrent.MoreExecutors; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.ChannelConfigurer; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.CompressorRegistry; @@ -42,6 +44,7 @@ import io.grpc.InternalConfigurator; import io.grpc.InternalConfiguratorRegistry; import io.grpc.InternalFeatureFlags; +import io.grpc.InternalManagedChannelBuilder; import io.grpc.InternalManagedChannelBuilder.InternalInterceptorFactory; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; @@ -74,6 +77,7 @@ import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameters; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -780,6 +784,80 @@ public void setNameResolverExtArgs() { assertThat(builder.nameResolverCustomArgs.get(testKey)).isEqualTo(42); } + @Test + public void childChannelConfigurer_setsField() { + ChannelConfigurer configurer = mock(ChannelConfigurer.class); + assertSame(builder, builder.childChannelConfigurer(configurer)); + assertSame(configurer, builder.channelConfigurer); + } + + @Test + public void childChannelConfigurer_propagatesMetricsAndInterceptors_xdsTarget() { + // Setup Mocks + when(mockClientTransportFactory.getScheduledExecutorService()) + .thenReturn(clock.getScheduledExecutorService()); + when(mockClientTransportFactoryBuilder.buildClientTransportFactory()) + .thenReturn(mockClientTransportFactory); + when(mockClientTransportFactory.getSupportedSocketAddressTypes()) + .thenReturn(Collections.singleton(InetSocketAddress.class)); + + MetricSink mockMetricSink = mock(MetricSink.class); + ClientInterceptor mockInterceptor = mock(ClientInterceptor.class); + + // Define the Configurer + ChannelConfigurer configurer = new ChannelConfigurer() { + @Override + public void configureChannelBuilder(ManagedChannelBuilder builder) { + builder.addMetricSink(mockMetricSink); + + InternalManagedChannelBuilder.interceptWithTarget(builder, target -> mockInterceptor); + } + }; + + // Mock NameResolver.Factory to capture Args + NameResolver.Factory mockNameResolverFactory = mock(NameResolver.Factory.class); + when(mockNameResolverFactory.getDefaultScheme()).thenReturn("xds"); + NameResolver mockNameResolver = mock(NameResolver.class); + when(mockNameResolver.getServiceAuthority()).thenReturn("foo.authority"); + ArgumentCaptor argsCaptor = ArgumentCaptor.forClass(NameResolver.Args.class); + when(mockNameResolverFactory.newNameResolver((URI) any(), + argsCaptor.capture())).thenReturn(mockNameResolver); + + // Use the configurer and the mock factory + NameResolverRegistry registry = new NameResolverRegistry(); + registry.register(new NameResolverFactoryToProviderFacade(mockNameResolverFactory)); + + ManagedChannelBuilder parentBuilder = new ManagedChannelImplBuilder( + "xds:///my-service-target", + mockClientTransportFactoryBuilder, + new FixedPortProvider(DUMMY_PORT)) + .childChannelConfigurer(configurer) + .nameResolverRegistry(registry); + + ManagedChannel channel = parentBuilder.build(); + grpcCleanupRule.register(channel); + + // Verify that newNameResolver was called + verify(mockNameResolverFactory).newNameResolver((URI) any(), any()); + + // Extract the childChannelConfigurer from Args + NameResolver.Args args = argsCaptor.getValue(); + ChannelConfigurer channelConfigurerInArgs = args.getChildChannelConfigurer(); + assertNotNull("Child channel configurer should be present in NameResolver.Args", + channelConfigurerInArgs); + + // Verify the configurer is the one we passed + assertThat(channelConfigurerInArgs).isSameInstanceAs(configurer); + + // Verify the configurer logically applies (by running it on a mock) + ManagedChannelBuilder mockChildBuilder = mock(ManagedChannelBuilder.class); + // Stub addMetricSink to return the builder to avoid generic return type issues + doReturn(mockChildBuilder).when(mockChildBuilder).addMetricSink(any()); + + configurer.configureChannelBuilder(mockChildBuilder); + verify(mockChildBuilder).addMetricSink(mockMetricSink); + } + @Test public void metricSinks() { MetricSink mocksink = mock(MetricSink.class); diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index ae224af27e1..d55288be357 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -303,6 +303,7 @@ public String getPolicyName() { private ArgumentCaptor streamListenerCaptor = ArgumentCaptor.forClass(ClientStreamListener.class); + private void createChannel(ClientInterceptor... interceptors) { createChannel(false, interceptors); } diff --git a/core/src/test/java/io/grpc/internal/MetricRecorderImplTest.java b/core/src/test/java/io/grpc/internal/MetricRecorderImplTest.java index 33bf9bb41e2..715b09a6cc4 100644 --- a/core/src/test/java/io/grpc/internal/MetricRecorderImplTest.java +++ b/core/src/test/java/io/grpc/internal/MetricRecorderImplTest.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyList; @@ -326,4 +327,25 @@ public void recordLongGaugeMismatchedOptionalLabelValues() { callbackCaptor.getValue().run(); registration.close(); } + + @Test + public void noOp_returnsSingleton() { + assertThat(MetricRecorder.noOp()).isSameInstanceAs(MetricRecorder.noOp()); + } + + @Test + public void noOp_methodsDoNotThrow() { + MetricRecorder recorder = MetricRecorder.noOp(); + + recorder.addDoubleCounter(null, 1.0, + null, null); + recorder.addLongCounter(null, 1, + null, null); + recorder.addLongUpDownCounter(null, 1, + null, null); + recorder.recordDoubleHistogram(null, 1.0, + null, null); + recorder.recordLongHistogram(null, 1, + null, null); + } } diff --git a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java index 0da51bf47f7..7da696ac49e 100644 --- a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java +++ b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java @@ -21,11 +21,13 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.CallCredentials; import io.grpc.CallOptions; +import io.grpc.ChannelConfigurer; import io.grpc.ChannelCredentials; import io.grpc.ClientCall; import io.grpc.Context; import io.grpc.Grpc; import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -36,14 +38,18 @@ final class GrpcXdsTransportFactory implements XdsTransportFactory { private final CallCredentials callCredentials; + private final ChannelConfigurer channelConfigurer; - GrpcXdsTransportFactory(CallCredentials callCredentials) { + + GrpcXdsTransportFactory(CallCredentials callCredentials, + ChannelConfigurer channelConfigurer) { this.callCredentials = callCredentials; + this.channelConfigurer = channelConfigurer; } @Override public XdsTransport create(Bootstrapper.ServerInfo serverInfo) { - return new GrpcXdsTransport(serverInfo, callCredentials); + return new GrpcXdsTransport(serverInfo, callCredentials, channelConfigurer); } @VisibleForTesting @@ -75,6 +81,20 @@ public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials call this.callCredentials = callCredentials; } + public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, + CallCredentials callCredentials, + ChannelConfigurer channelConfigurer) { + String target = serverInfo.target(); + ChannelCredentials channelCredentials = (ChannelCredentials) serverInfo.implSpecificConfig(); + ManagedChannelBuilder channelBuilder = Grpc.newChannelBuilder(target, channelCredentials) + .keepAliveTime(5, TimeUnit.MINUTES); + if (channelConfigurer != null) { + channelConfigurer.configureChannelBuilder(channelBuilder); + } + this.channel = channelBuilder.build(); + this.callCredentials = callCredentials; + } + @VisibleForTesting public GrpcXdsTransport(ManagedChannel channel, CallCredentials callCredentials) { this.channel = checkNotNull(channel, "channel"); diff --git a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java index cc5ff128274..06ce7eb6f53 100644 --- a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java @@ -86,7 +86,8 @@ public static XdsClientResult getOrCreate( String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder, CallCredentials transportCallCredentials) { return new XdsClientResult(SharedXdsClientPoolProvider.getDefaultProvider() - .getOrCreate(target, bootstrapInfo, metricRecorder, transportCallCredentials)); + .getOrCreate(target, bootstrapInfo, metricRecorder, transportCallCredentials, + null)); } /** diff --git a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java index 45c379244af..bccf63c475f 100644 --- a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.CallCredentials; +import io.grpc.ChannelConfigurer; import io.grpc.MetricRecorder; import io.grpc.internal.ExponentialBackoffPolicy; import io.grpc.internal.GrpcUtil; @@ -57,6 +58,10 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory { @Nullable private final Bootstrapper bootstrapper; private final Object lock = new Object(); + /* + The first one wins. + Anything with the same target string uses the client created for the first one. + */ private final Map> targetToXdsClientMap = new ConcurrentHashMap<>(); SharedXdsClientPoolProvider() { @@ -88,20 +93,31 @@ public ObjectPool getOrCreate( } else { bootstrapInfo = GrpcBootstrapperImpl.defaultBootstrap(); } - return getOrCreate(target, bootstrapInfo, metricRecorder, transportCallCredentials); + return getOrCreate(target, bootstrapInfo, metricRecorder, transportCallCredentials, + null); } @Override public ObjectPool getOrCreate( String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder) { - return getOrCreate(target, bootstrapInfo, metricRecorder, null); + return getOrCreate(target, bootstrapInfo, metricRecorder, null, + null); + } + + @Override + public ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder, + ChannelConfigurer channelConfigurer) { + return getOrCreate(target, bootstrapInfo, metricRecorder, null, + channelConfigurer); } public ObjectPool getOrCreate( String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder, - CallCredentials transportCallCredentials) { + CallCredentials transportCallCredentials, + ChannelConfigurer channelConfigurer) { ObjectPool ref = targetToXdsClientMap.get(target); if (ref == null) { synchronized (lock) { @@ -109,7 +125,8 @@ public ObjectPool getOrCreate( if (ref == null) { ref = new RefCountedXdsClientObjectPool( - bootstrapInfo, target, metricRecorder, transportCallCredentials); + bootstrapInfo, target, metricRecorder, transportCallCredentials, + channelConfigurer); targetToXdsClientMap.put(target, ref); } } @@ -134,6 +151,7 @@ class RefCountedXdsClientObjectPool implements ObjectPool { private final String target; // The target associated with the xDS client. private final MetricRecorder metricRecorder; private final CallCredentials transportCallCredentials; + private final ChannelConfigurer channelConfigurer; private final Object lock = new Object(); @GuardedBy("lock") private ScheduledExecutorService scheduler; @@ -147,7 +165,7 @@ class RefCountedXdsClientObjectPool implements ObjectPool { @VisibleForTesting RefCountedXdsClientObjectPool( BootstrapInfo bootstrapInfo, String target, MetricRecorder metricRecorder) { - this(bootstrapInfo, target, metricRecorder, null); + this(bootstrapInfo, target, metricRecorder, null, null); } @VisibleForTesting @@ -155,11 +173,13 @@ class RefCountedXdsClientObjectPool implements ObjectPool { BootstrapInfo bootstrapInfo, String target, MetricRecorder metricRecorder, - CallCredentials transportCallCredentials) { + CallCredentials transportCallCredentials, + ChannelConfigurer channelConfigurer) { this.bootstrapInfo = checkNotNull(bootstrapInfo, "bootstrapInfo"); this.target = target; this.metricRecorder = checkNotNull(metricRecorder, "metricRecorder"); this.transportCallCredentials = transportCallCredentials; + this.channelConfigurer = channelConfigurer; } @Override @@ -172,7 +192,7 @@ public XdsClient getObject() { scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE); metricReporter = new XdsClientMetricReporterImpl(metricRecorder, target); GrpcXdsTransportFactory xdsTransportFactory = - new GrpcXdsTransportFactory(transportCallCredentials); + new GrpcXdsTransportFactory(transportCallCredentials, channelConfigurer); xdsClient = new XdsClientImpl( xdsTransportFactory, diff --git a/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java b/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java index 6df8d566a7a..5e16d1225fa 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java +++ b/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java @@ -16,6 +16,7 @@ package io.grpc.xds; +import io.grpc.ChannelConfigurer; import io.grpc.MetricRecorder; import io.grpc.internal.ObjectPool; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; @@ -30,5 +31,9 @@ interface XdsClientPoolFactory { ObjectPool getOrCreate( String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder); + ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder, + ChannelConfigurer channelConfigurer); + List getTargets(); } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index 196d51fb5a6..5c27263f6f2 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -31,6 +31,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.ChannelConfigurer; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; @@ -182,7 +183,8 @@ final class XdsNameResolver extends NameResolver { } else { checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); this.xdsClientPool = new BootstrappingXdsClientPool( - xdsClientPoolFactory, target, bootstrapOverride, metricRecorder); + xdsClientPoolFactory, target, bootstrapOverride, metricRecorder, + nameResolverArgs.getChildChannelConfigurer()); } this.random = checkNotNull(random, "random"); this.filterRegistry = checkNotNull(filterRegistry, "filterRegistry"); @@ -1054,16 +1056,19 @@ private static final class BootstrappingXdsClientPool implements XdsClientPool { private final @Nullable Map bootstrapOverride; private final @Nullable MetricRecorder metricRecorder; private ObjectPool xdsClientPool; + private ChannelConfigurer channelConfigurer; BootstrappingXdsClientPool( XdsClientPoolFactory xdsClientPoolFactory, String target, @Nullable Map bootstrapOverride, - @Nullable MetricRecorder metricRecorder) { + @Nullable MetricRecorder metricRecorder, + @Nullable ChannelConfigurer channelConfigurer) { this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); this.target = checkNotNull(target, "target"); this.bootstrapOverride = bootstrapOverride; this.metricRecorder = metricRecorder; + this.channelConfigurer = channelConfigurer; } @Override @@ -1076,7 +1081,8 @@ public XdsClient getObject() throws XdsInitializationException { bootstrapInfo = new GrpcBootstrapperImpl().bootstrap(bootstrapOverride); } this.xdsClientPool = - xdsClientPoolFactory.getOrCreate(target, bootstrapInfo, metricRecorder); + xdsClientPoolFactory.getOrCreate( + target, bootstrapInfo, metricRecorder, channelConfigurer); } return xdsClientPool.getObject(); } diff --git a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java index 4a4fb71aa84..febc9567829 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java @@ -25,6 +25,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.errorprone.annotations.DoNotCall; import io.grpc.Attributes; +import io.grpc.ChannelConfigurer; import io.grpc.ExperimentalApi; import io.grpc.ForwardingServerBuilder; import io.grpc.Internal; @@ -58,6 +59,7 @@ public final class XdsServerBuilder extends ForwardingServerBuilder bootstrapOverride; private long drainGraceTime = 10; private TimeUnit drainGraceTimeUnit = TimeUnit.MINUTES; + private ChannelConfigurer channelConfigurer = new ChannelConfigurer() {}; private XdsServerBuilder(NettyServerBuilder nettyDelegate, int port) { this.delegate = nettyDelegate; @@ -100,6 +102,20 @@ public XdsServerBuilder drainGraceTime(long drainGraceTime, TimeUnit drainGraceT return this; } + /** + * Sets the configurer that will be stored in the server built by this builder. + * + *

This configurer will subsequently be used to configure any child channels + * created by that server. + * + * @param channelConfigurer the configurer to store in the channel. + */ + @Override + public XdsServerBuilder childChannelConfigurer(ChannelConfigurer channelConfigurer) { + this.channelConfigurer = channelConfigurer; + return this; + } + @DoNotCall("Unsupported. Use forPort(int, ServerCredentials) instead") public static ServerBuilder forPort(int port) { throw new UnsupportedOperationException( @@ -128,7 +144,8 @@ public Server build() { } InternalNettyServerBuilder.eagAttributes(delegate, builder.build()); return new XdsServerWrapper("0.0.0.0:" + port, delegate, xdsServingStatusListener, - filterChainSelectorManager, xdsClientPoolFactory, bootstrapOverride, filterRegistry); + filterChainSelectorManager, xdsClientPoolFactory, bootstrapOverride, filterRegistry, + this.channelConfigurer); } @VisibleForTesting diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index 5529f96c7a2..7e633d75242 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -29,6 +29,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import io.grpc.Attributes; +import io.grpc.ChannelConfigurer; import io.grpc.InternalServerInterceptors; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -128,6 +129,8 @@ public void uncaughtException(Thread t, Throwable e) { // NamedFilterConfig.filterStateKey -> filter_instance. private final HashMap activeFiltersDefaultChain = new HashMap<>(); + private ChannelConfigurer channelConfigurer = new ChannelConfigurer() {}; + XdsServerWrapper( String listenerAddress, ServerBuilder delegateBuilder, @@ -148,6 +151,30 @@ public void uncaughtException(Thread t, Throwable e) { sharedTimeService = true; } + XdsServerWrapper( + String listenerAddress, + ServerBuilder delegateBuilder, + XdsServingStatusListener listener, + FilterChainSelectorManager filterChainSelectorManager, + XdsClientPoolFactory xdsClientPoolFactory, + @Nullable Map bootstrapOverride, + FilterRegistry filterRegistry, + ChannelConfigurer channelConfigurer) { + this( + listenerAddress, + delegateBuilder, + listener, + filterChainSelectorManager, + xdsClientPoolFactory, + bootstrapOverride, + filterRegistry, + SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE)); + sharedTimeService = true; + if (channelConfigurer != null) { + this.channelConfigurer = channelConfigurer; + } + } + @VisibleForTesting XdsServerWrapper( String listenerAddress, @@ -202,7 +229,8 @@ private void internalStart() { bootstrapInfo = new GrpcBootstrapperImpl().bootstrap(bootstrapOverride); } xdsClientPool = xdsClientPoolFactory.getOrCreate( - "#server", bootstrapInfo, new MetricRecorder() {}); + "#server", bootstrapInfo, new MetricRecorder() {}, + channelConfigurer); } catch (Exception e) { StatusException statusException = Status.UNAVAILABLE.withDescription( "Failed to initialize xDS").withCause(e).asException(); diff --git a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java index e8bd7461736..70ddd8ba795 100644 --- a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java +++ b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java @@ -37,6 +37,7 @@ import io.envoyproxy.envoy.service.status.v3.ClientStatusRequest; import io.envoyproxy.envoy.service.status.v3.ClientStatusResponse; import io.envoyproxy.envoy.type.matcher.v3.NodeMatcher; +import io.grpc.ChannelConfigurer; import io.grpc.Deadline; import io.grpc.InsecureChannelCredentials; import io.grpc.MetricRecorder; @@ -517,5 +518,12 @@ public ObjectPool getOrCreate( String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder) { throw new UnsupportedOperationException("Should not be called"); } + + @Override + public ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder, + ChannelConfigurer channelConfigurer) { + throw new UnsupportedOperationException("Should not be called"); + } } } diff --git a/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java index c8f2b8932ef..1de47cf2b73 100644 --- a/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java +++ b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java @@ -22,6 +22,13 @@ import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; import static org.junit.Assert.assertEquals; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; import com.github.xds.type.v3.TypedStruct; import com.google.common.collect.ImmutableMap; @@ -47,15 +54,23 @@ import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.ChannelConfigurer; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.ClientStreamTracer; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.MetricSink; +import io.grpc.NoopMetricSink; +import io.grpc.Server; import io.grpc.testing.protobuf.SimpleRequest; import io.grpc.testing.protobuf.SimpleResponse; import io.grpc.testing.protobuf.SimpleServiceGrpc; @@ -339,4 +354,63 @@ public void pingPong_logicalDns_authorityOverride() { System.clearProperty("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE"); } } + + @Test + public void childChannelConfigurer_passesMetricSinkToChannel_E2E() { + MetricSink mockSink = mock(MetricSink.class, delegatesTo(new NoopMetricSink())); + ChannelConfigurer configurer = new ChannelConfigurer() { + @Override + public void configureChannelBuilder(ManagedChannelBuilder builder) { + builder.addMetricSink(mockSink); + } + }; + + ManagedChannel channel = Grpc.newChannelBuilder("test-xds:///test-server", + InsecureChannelCredentials.create()) + .childChannelConfigurer(configurer) + .build(); + + try { + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( + channel); + blockingStub.unaryRpc(SimpleRequest.getDefaultInstance()); + + // The xDS client inside the channel configurer will have created an ADS stream. + // The metric sink should have received attempt or connection metrics. + verify(mockSink, timeout(5000).atLeastOnce()) + .addLongCounter(any(), anyLong(), anyList(), anyList()); + } finally { + channel.shutdownNow(); + } + } + + @Test + public void childChannelConfigurer_passesMetricSinkToServer_E2E() throws Exception { + MetricSink mockSink = mock(MetricSink.class, delegatesTo(new NoopMetricSink())); + ChannelConfigurer configurer = new ChannelConfigurer() { + @Override + public void configureChannelBuilder(ManagedChannelBuilder builder) { + // Child channels (xDS client connections) created by this server get the sink. + builder.addMetricSink(mockSink); + } + }; + + // We start an XdsServer manually. + // XdsServer needs RDS, LDS, etc. from control plane. + XdsServerBuilder serverBuilder = XdsServerBuilder.forPort( + 0, InsecureServerCredentials.create()) + .addService(new SimpleServiceGrpc.SimpleServiceImplBase() {}) + .overrideBootstrapForTest(controlPlane.defaultBootstrapOverride()) + .childChannelConfigurer(configurer); + + Server childServer = serverBuilder.build().start(); + + try { + // The server xDS client will connect to control plane to get LDS. + verify(mockSink, timeout(5000).atLeastOnce()) + .addLongCounter(any(), anyLong(), anyList(), anyList()); + } finally { + childServer.shutdownNow(); + } + } } diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java index af55e572811..1e8d3cd29d3 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java @@ -5126,7 +5126,7 @@ public void serverFailureMetricReport_forRetryAndBackoff() { private XdsClientImpl createXdsClient(String serverUri) { BootstrapInfo bootstrapInfo = buildBootStrap(serverUri); return new XdsClientImpl( - new GrpcXdsTransportFactory(null), + new GrpcXdsTransportFactory(null, null), bootstrapInfo, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java index 66e0d4b3198..a97490ed9bb 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java @@ -17,15 +17,24 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import com.google.common.util.concurrent.SettableFuture; import io.envoyproxy.envoy.service.discovery.v3.AggregatedDiscoveryServiceGrpc; import io.envoyproxy.envoy.service.discovery.v3.DiscoveryRequest; import io.envoyproxy.envoy.service.discovery.v3.DiscoveryResponse; import io.grpc.BindableService; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ChannelConfigurer; +import io.grpc.ClientInterceptor; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; import io.grpc.InsecureServerCredentials; +import io.grpc.ManagedChannelBuilder; import io.grpc.MethodDescriptor; import io.grpc.Server; import io.grpc.Status; @@ -92,7 +101,7 @@ public void onCompleted() { @Test public void callApis() throws Exception { XdsTransportFactory.XdsTransport xdsTransport = - new GrpcXdsTransportFactory(null) + new GrpcXdsTransportFactory(null, null) .create( Bootstrapper.ServerInfo.create( "localhost:" + server.getPort(), InsecureChannelCredentials.create())); @@ -139,5 +148,110 @@ public void onStatusReceived(Status status) { endFuture.set(status); } } + + @Test + @SuppressWarnings("unchecked") + public void verifyConfigApplied_interceptor() { + // Create a mock Interceptor + final ClientInterceptor mockInterceptor = mock(ClientInterceptor.class); + when(mockInterceptor.interceptCall(any(MethodDescriptor.class), + any(CallOptions.class), any(Channel.class))) + .thenReturn(new io.grpc.NoopClientCall<>()); + + // Create Configurer that adds the interceptor + ChannelConfigurer configurer = new ChannelConfigurer() { + @Override + public void configureChannelBuilder(ManagedChannelBuilder builder) { + builder.intercept(mockInterceptor); + } + }; + + // Create Factory + GrpcXdsTransportFactory factory = new GrpcXdsTransportFactory( + null, + configurer); + + // Create Transport + XdsTransportFactory.XdsTransport transport = factory.create( + Bootstrapper.ServerInfo.create("localhost:8080", InsecureChannelCredentials.create())); + + // Create a Call to trigger interceptors + MethodDescriptor method = MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("service/method") + .setFullMethodName("service/method") + .setRequestMarshaller(mock(MethodDescriptor.Marshaller.class)) + .setResponseMarshaller(mock(MethodDescriptor.Marshaller.class)) + .build(); + + transport.createStreamingCall(method.getFullMethodName(), method.getRequestMarshaller(), + method.getResponseMarshaller()); + + // Verify interceptor was invoked + verify(mockInterceptor).interceptCall(any(MethodDescriptor.class), + any(CallOptions.class), any(Channel.class)); + + transport.shutdown(); + } + + @Test + public void useChannelConfigurer() { + // Mock Configurer + ChannelConfigurer mockConfigurer = mock(ChannelConfigurer.class); + + // Create Factory + GrpcXdsTransportFactory factory = new GrpcXdsTransportFactory( + null, // CallCredentials + mockConfigurer); + + // Create Transport (triggers channel creation) + XdsTransportFactory.XdsTransport transport = factory.create( + Bootstrapper.ServerInfo.create("localhost:8080", InsecureChannelCredentials.create())); + + // Verify Configurer was accessed and applied + verify(mockConfigurer).configureChannelBuilder(any(ManagedChannelBuilder.class)); + + transport.shutdown(); + } + + @Test + public void verifyConfigApplied_maxInboundMessageSize() { + // Create a mock Builder + ManagedChannelBuilder mockBuilder = mock(ManagedChannelBuilder.class); + + // Create Configurer that modifies message size + ChannelConfigurer configurer = new ChannelConfigurer() { + @Override + public void configureChannelBuilder(ManagedChannelBuilder builder) { + builder.maxInboundMessageSize(1024); + } + }; + + // Apply configurer to builder + configurer.configureChannelBuilder(mockBuilder); + + // Verify builder was modified + verify(mockBuilder).maxInboundMessageSize(1024); + } + + @Test + public void verifyConfigApplied_interceptors() { + ClientInterceptor interceptor1 = mock(ClientInterceptor.class); + ClientInterceptor interceptor2 = mock(ClientInterceptor.class); + + ChannelConfigurer configurer = new ChannelConfigurer() { + @Override + public void configureChannelBuilder(ManagedChannelBuilder builder) { + builder.intercept(interceptor1); + builder.intercept(interceptor2); + } + }; + + ManagedChannelBuilder mockBuilder = mock(ManagedChannelBuilder.class); + configurer.configureChannelBuilder(mockBuilder); + + verify(mockBuilder).intercept(interceptor1); + verify(mockBuilder).intercept(interceptor2); + } } diff --git a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java index 9bdf86132b6..bfc7db19750 100644 --- a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java @@ -181,7 +181,7 @@ public void cancelled(Context context) { lrsClient = new LoadReportClient( loadStatsManager, - new GrpcXdsTransportFactory(null).createForTest(channel), + new GrpcXdsTransportFactory(null, null).createForTest(channel), NODE, syncContext, fakeClock.getScheduledExecutorService(), diff --git a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java index 29b149f166f..89a094046fe 100644 --- a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java @@ -28,9 +28,12 @@ import com.google.auth.oauth2.OAuth2Credentials; import com.google.common.util.concurrent.SettableFuture; import io.grpc.CallCredentials; +import io.grpc.ChannelConfigurer; +import io.grpc.ClientInterceptor; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; import io.grpc.InsecureServerCredentials; +import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; import io.grpc.MetricRecorder; import io.grpc.Server; @@ -207,7 +210,8 @@ public void xdsClient_usesCallCredentials() throws Exception { // Create xDS client that uses the CallCredentials on the transport ObjectPool xdsClientPool = - provider.getOrCreate("target", bootstrapInfo, metricRecorder, sampleCreds); + provider.getOrCreate("target", bootstrapInfo, metricRecorder, sampleCreds, + null); XdsClient xdsClient = xdsClientPool.getObject(); xdsClient.watchXdsResource( XdsListenerResource.getInstance(), "someLDSresource", ldsResourceWatcher); @@ -220,4 +224,64 @@ public void xdsClient_usesCallCredentials() throws Exception { xdsClientPool.returnObject(xdsClient); xdsServer.shutdownNow(); } + + @Test + public void xdsClient_usesChannelConfigurer() throws Exception { + // Set up fake xDS server + XdsTestControlPlaneService fakeXdsService = new XdsTestControlPlaneService(); + CallCredsServerInterceptor callInterceptor = new CallCredsServerInterceptor(); + Server xdsServer = + Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) + .addService(fakeXdsService) + .intercept(callInterceptor) + .build() + .start(); + String xdsServerUri = "localhost:" + xdsServer.getPort(); + + // Set up bootstrap & xDS client pool provider + ServerInfo server = ServerInfo.create(xdsServerUri, InsecureChannelCredentials.create()); + BootstrapInfo bootstrapInfo = + BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(); + + // Create a client interceptor that actually just injects a test token + ClientInterceptor testInterceptor = new ClientInterceptor() { + @Override + public io.grpc.ClientCall interceptCall( + io.grpc.MethodDescriptor method, + io.grpc.CallOptions callOptions, + io.grpc.Channel next) { + return new io.grpc.ForwardingClientCall.SimpleForwardingClientCall( + next.newCall(method, callOptions)) { + @Override + public void start(Listener responseListener, Metadata headers) { + headers.put(AUTHORIZATION_METADATA_KEY, "Bearer test-configurer-token"); + super.start(responseListener, headers); + } + }; + } + }; + + ChannelConfigurer configurer = new ChannelConfigurer() { + @Override + public void configureChannelBuilder(ManagedChannelBuilder builder) { + builder.intercept(testInterceptor); + } + }; + + // Create xDS client that uses the ChannelConfigurer on the transport + ObjectPool xdsClientPool = + provider.getOrCreate("target", bootstrapInfo, metricRecorder, null, configurer); + XdsClient xdsClient = xdsClientPool.getObject(); + xdsClient.watchXdsResource( + XdsListenerResource.getInstance(), "someLDSresource", ldsResourceWatcher); + + // Wait for xDS server to get the request and verify that it received the token from configurer + assertThat(callInterceptor.getTokenWithTimeout(5, TimeUnit.SECONDS)) + .isEqualTo("Bearer test-configurer-token"); + + // Clean up + xdsClientPool.returnObject(xdsClient); + xdsServer.shutdownNow(); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java b/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java index 27ee8d22825..4d5e7d09ad4 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java @@ -484,7 +484,7 @@ public void fallbackFromBadUrlToGoodOne() { XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient( Arrays.asList(garbageUri, validUri), - new GrpcXdsTransportFactory(null), + new GrpcXdsTransportFactory(null, null), fakeClock, new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, @@ -509,7 +509,7 @@ public void testGoodUrlFollowedByBadUrl() { XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient( Arrays.asList(validUri, garbageUri), - new GrpcXdsTransportFactory(null), + new GrpcXdsTransportFactory(null, null), fakeClock, new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, @@ -536,7 +536,7 @@ public void testTwoBadUrl() { XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient( Arrays.asList(garbageUri1, garbageUri2), - new GrpcXdsTransportFactory(null), + new GrpcXdsTransportFactory(null, null), fakeClock, new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index 45a96ee172f..af65c4d9687 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -27,6 +27,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.lenient; @@ -47,6 +48,7 @@ import com.google.re2j.Pattern; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.ChannelConfigurer; import io.grpc.ChannelLogger; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; @@ -67,6 +69,7 @@ import io.grpc.NameResolver.ServiceConfigParser; import io.grpc.NoopClientCall; import io.grpc.NoopClientCall.NoopClientCallListener; +import io.grpc.ProxyDetector; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StatusOr; @@ -2493,6 +2496,24 @@ public XdsClient returnObject(Object object) { }; } + @Override + public ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder, + ChannelConfigurer channelConfigurer) { + targets.add(target); + return new ObjectPool() { + @Override + public XdsClient getObject() { + return xdsClient; + } + + @Override + public XdsClient returnObject(Object object) { + return null; + } + }; + } + @Override public List getTargets() { if (targets.isEmpty()) { @@ -2931,4 +2952,60 @@ void deliverErrorStatus() { listener.onClose(Status.UNAVAILABLE, new Metadata()); } } + + @Test + public void start_passesParentChannelToClientPoolFactory() { + ChannelConfigurer mockChannelConfigurer = mock(ChannelConfigurer.class); + + // Build NameResolver.Args containing the child channel configurer + NameResolver.Args args = NameResolver.Args.newBuilder() + .setDefaultPort(8080) + .setProxyDetector(mock(ProxyDetector.class)) + .setSynchronizationContext(syncContext) + .setServiceConfigParser(serviceConfigParser) + .setChannelLogger(mock(ChannelLogger.class)) + .setChildChannelConfigurer(mockChannelConfigurer) + .build(); + + // Mock the XdsClientPoolFactory + XdsClientPoolFactory mockPoolFactory = mock(XdsClientPoolFactory.class); + @SuppressWarnings("unchecked") + ObjectPool mockObjectPool = mock(ObjectPool.class); + XdsClient mockXdsClient = mock(XdsClient.class); + when(mockObjectPool.getObject()).thenReturn(mockXdsClient); + when(mockXdsClient.getBootstrapInfo()).thenReturn(bootstrapInfo); + + when(mockPoolFactory.getOrCreate( + anyString(), + any(BootstrapInfo.class), + any(MetricRecorder.class), + any(ChannelConfigurer.class))) + .thenReturn(mockObjectPool); + + XdsNameResolver resolver = new XdsNameResolver( + URI.create(AUTHORITY), + null, // targetAuthority (nullable) + AUTHORITY, // name + null, // overrideAuthority (nullable) + serviceConfigParser, + syncContext, + scheduler, + mockPoolFactory, + mockRandom, + FilterRegistry.getDefaultRegistry(), + rawBootstrap, + metricRecorder, + args); + + // Start the resolver (this triggers the factory call) + resolver.start(mockListener); + + verify(mockPoolFactory).getOrCreate( + eq(AUTHORITY), + any(BootstrapInfo.class), + eq(metricRecorder), + eq(mockChannelConfigurer)); + + resolver.shutdown(); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java index ac990226259..770bdae1eac 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java @@ -30,15 +30,18 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.SettableFuture; import io.grpc.BindableService; +import io.grpc.ChannelConfigurer; import io.grpc.InsecureServerCredentials; import io.grpc.ServerServiceDefinition; import io.grpc.Status; import io.grpc.StatusException; import io.grpc.StatusOr; +import io.grpc.internal.ObjectPool; import io.grpc.testing.GrpcCleanupRule; import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; +import io.grpc.xds.client.XdsClient; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import java.io.IOException; import java.net.InetSocketAddress; @@ -321,8 +324,29 @@ public void testOverrideBootstrap() throws Exception { buildBuilder(null); builder.overrideBootstrapForTest(b); xdsServer = cleanupRule.register((XdsServerWrapper) builder.build()); - Future unused = startServerAsync(); + Future unused = startServerAsync(); assertThat(xdsClientPoolFactory.savedBootstrapInfo.node().getId()) .isEqualTo(XdsServerTestHelper.BOOTSTRAP_INFO.node().getId()); } + + @Test + public void start_passesParentServerToClientPoolFactory() throws Exception { + ChannelConfigurer mockConfigurer = mock(ChannelConfigurer.class); + XdsClientPoolFactory mockPoolFactory = mock(XdsClientPoolFactory.class); + @SuppressWarnings("unchecked") + ObjectPool mockPool = mock(ObjectPool.class); + when(mockPool.getObject()).thenReturn(xdsClient); + when(mockPoolFactory.getOrCreate(any(), any(), any(), any())).thenReturn(mockPool); + + buildBuilder(null); + builder.childChannelConfigurer(mockConfigurer); + builder.xdsClientPoolFactory(mockPoolFactory); + xdsServer = cleanupRule.register((XdsServerWrapper) builder.build()); + + Future unused = startServerAsync(); + + // Verify getOrCreate called with the server instance + verify(mockPoolFactory).getOrCreate( + any(), any(), any(), org.mockito.ArgumentMatchers.eq(mockConfigurer)); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index 386793299d8..56d5021691a 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.SettableFuture; import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; +import io.grpc.ChannelConfigurer; import io.grpc.InsecureChannelCredentials; import io.grpc.MetricRecorder; import io.grpc.Status; @@ -182,6 +183,25 @@ public XdsClient returnObject(Object object) { }; } + @Override + public ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder, + ChannelConfigurer channelConfigurer) { + this.savedBootstrapInfo = bootstrapInfo; + return new ObjectPool() { + @Override + public XdsClient getObject() { + return xdsClient; + } + + @Override + public XdsClient returnObject(Object object) { + xdsClient.shutdown(); + return null; + } + }; + } + @Override public List getTargets() { return Collections.singletonList("fake-target");