diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java index f21a9b99f1..f31794ac36 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java @@ -282,6 +282,15 @@ static GrpcTlsConfig tlsConf(Parameters parameters) { static void setTlsConf(Parameters parameters, GrpcTlsConfig conf) { parameters.put(TLS_CONF_PARAMETER, conf, TLS_CONF_CLASS); } + + String STUB_POOL_SIZE_KEY = PREFIX + ".stub.pool.size"; + int STUB_POOL_SIZE_DEFAULT = 1; + static int stubPoolSize(RaftProperties properties) { + return get(properties::getInt, STUB_POOL_SIZE_KEY, STUB_POOL_SIZE_DEFAULT, getDefaultLog()); + } + static void setStubPoolSize(RaftProperties properties, int size) { + setInt(properties::setInt, STUB_POOL_SIZE_KEY, size); + } } String MESSAGE_SIZE_MAX_KEY = PREFIX + ".message.size.max"; diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java index 1e40a75ada..d2748c7be2 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java @@ -45,6 +45,7 @@ class GrpcServerProtocolClient implements Closeable { // Common channel private final ManagedChannel channel; + private final GrpcStubPool pool; // Channel and stub for heartbeat private ManagedChannel hbChannel; private RaftServerProtocolServiceStub hbAsyncStub; @@ -57,7 +58,7 @@ class GrpcServerProtocolClient implements Closeable { //visible for using in log / error messages AND to use in instrumented tests private final RaftPeerId raftPeerId; - GrpcServerProtocolClient(RaftPeer target, int flowControlWindow, + GrpcServerProtocolClient(RaftPeer target, int connections, int flowControlWindow, TimeDuration requestTimeout, SslContext sslContext, boolean separateHBChannel) { raftPeerId = target.getId(); LOG.info("Build channel for {}", target); @@ -70,6 +71,11 @@ class GrpcServerProtocolClient implements Closeable { hbAsyncStub = RaftServerProtocolServiceGrpc.newStub(hbChannel); } requestTimeoutDuration = requestTimeout; + this.pool = connections == 1? null : newGrpcStubPool(target.getAddress(), sslContext, connections); + } + + GrpcStubPool newGrpcStubPool(String address, SslContext sslContext, int connections) { + return new GrpcStubPool<>(connections, address, sslContext, RaftServerProtocolServiceGrpc::newStub, 16); } private ManagedChannel buildChannel(RaftPeer target, int flowControlWindow, SslContext sslContext) { @@ -94,6 +100,9 @@ public void close() { GrpcUtil.shutdownManagedChannel(hbChannel); } GrpcUtil.shutdownManagedChannel(channel); + if (pool != null) { + pool.close(); + } } public RequestVoteReplyProto requestVote(RequestVoteRequestProto request) { @@ -112,8 +121,44 @@ public StartLeaderElectionReplyProto startLeaderElection(StartLeaderElectionRequ } void readIndex(ReadIndexRequestProto request, StreamObserver s) { - asyncStub.withDeadlineAfter(requestTimeoutDuration.getDuration(), requestTimeoutDuration.getUnit()) - .readIndex(request, s); + if (pool == null) { + asyncStub.withDeadlineAfter(requestTimeoutDuration.getDuration(), requestTimeoutDuration.getUnit()) + .readIndex(request, s); + } else { + GrpcStubPool.Stub p; + try { + p = pool.acquire(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + s.onError(e); + return; + } + p.getStub().withDeadlineAfter(requestTimeoutDuration.getDuration(), requestTimeoutDuration.getUnit()) + .readIndex(request, new StreamObserver() { + @Override + public void onNext(ReadIndexReplyProto v) { + s.onNext(v); + } + + @Override + public void onError(Throwable t) { + try { + s.onError(t); + } finally { + p.release(); + } + } + + @Override + public void onCompleted() { + try { + s.onCompleted(); + } finally { + p.release(); + } + } + }); + } } CallStreamObserver appendEntries( diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java index 8200aa3ef7..b1af0960dc 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java @@ -108,6 +108,7 @@ public static final class Builder { private int serverPort; private SslContext serverSslContextForServer; private SslContext serverSslContextForClient; + private int serverStubPoolSize; private SizeInBytes messageSizeMax; private SizeInBytes flowControlWindow; @@ -130,6 +131,7 @@ public Builder setServer(RaftServer raftServer) { this.flowControlWindow = GrpcConfigKeys.flowControlWindow(properties, LOG::info); this.requestTimeoutDuration = RaftServerConfigKeys.Rpc.requestTimeout(properties); this.separateHeartbeatChannel = GrpcConfigKeys.Server.heartbeatChannel(properties); + this.serverStubPoolSize = GrpcConfigKeys.Server.stubPoolSize(properties); final SizeInBytes appenderBufferSize = RaftServerConfigKeys.Log.Appender.bufferByteLimit(properties); final SizeInBytes gap = SizeInBytes.ONE_MB; @@ -150,7 +152,7 @@ public Builder setCustomizer(Customizer customizer) { } private GrpcServerProtocolClient newGrpcServerProtocolClient(RaftPeer target) { - return new GrpcServerProtocolClient(target, flowControlWindow.getSizeInt(), + return new GrpcServerProtocolClient(target, serverStubPoolSize, flowControlWindow.getSizeInt(), requestTimeoutDuration, serverSslContextForClient, separateHeartbeatChannel); } diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcStubPool.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcStubPool.java new file mode 100644 index 0000000000..fd27ac996a --- /dev/null +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcStubPool.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ratis.grpc.server; + +import org.apache.ratis.thirdparty.io.grpc.ManagedChannel; +import org.apache.ratis.thirdparty.io.grpc.netty.NegotiationType; +import org.apache.ratis.thirdparty.io.grpc.netty.NettyChannelBuilder; +import org.apache.ratis.thirdparty.io.grpc.stub.AbstractStub; +import org.apache.ratis.thirdparty.io.netty.channel.ChannelOption; +import org.apache.ratis.thirdparty.io.netty.channel.WriteBufferWaterMark; +import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext; +import org.apache.ratis.util.MemoizedSupplier; +import org.apache.ratis.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +final class GrpcStubPool> { + public static final Logger LOG = LoggerFactory.getLogger(GrpcStubPool.class); + + static ManagedChannel buildManagedChannel(String address, SslContext sslContext) { + NettyChannelBuilder channelBuilder = NettyChannelBuilder.forTarget(address) + .keepAliveTime(10, TimeUnit.MINUTES) + .keepAliveWithoutCalls(false) + .idleTimeout(30, TimeUnit.MINUTES) + .withOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(64 << 10, 128 << 10)); + if (sslContext != null) { + LOG.debug("Setting TLS for {}", address); + channelBuilder.useTransportSecurity().sslContext(sslContext); + } else { + channelBuilder.negotiationType(NegotiationType.PLAINTEXT); + } + ManagedChannel ch = channelBuilder.build(); + ch.getState(true); + return ch; + } + + static final class Stub> { + private final ManagedChannel ch; + private final S stub; + private final Semaphore permits; + + Stub(String address, SslContext sslContext, Function stubFactory, int maxInflight) { + this.ch = buildManagedChannel(address, sslContext); + this.stub = stubFactory.apply(ch); + this.permits = new Semaphore(maxInflight); + } + + S getStub() { + return stub; + } + + void release() { + permits.release(); + } + + void shutdown() { + ch.shutdown(); + } + } + + private final List>> pool; + + GrpcStubPool(int connections, String address, SslContext sslContext, Function stubFactory, + int maxInflightPerConn) { + Preconditions.assertTrue(connections > 1, "connections must be > 1"); + final List>> tmpPool = new ArrayList<>(connections); + for (int i = 0; i < connections; i++) { + tmpPool.add(MemoizedSupplier.valueOf(() -> new Stub<>(address, sslContext, stubFactory, maxInflightPerConn))); + } + this.pool = Collections.unmodifiableList(tmpPool); + } + + Stub getStub(int i) { + return pool.get(i).get(); + } + + Stub acquire() throws InterruptedException { + final int size = pool.size(); + final int start = ThreadLocalRandom.current().nextInt(size); + for (int k = 0; k < size; k++) { + Stub p = getStub((start + k) % size); + if (p.permits.tryAcquire()) { + return p; + } + } + final Stub p = getStub(start); + p.permits.acquire(); + return p; + } + + public void close() { + for (MemoizedSupplier> p : pool) { + if (p.isInitialized()) { + p.get().shutdown(); + } + } + } +}