Skip to content

Commit d30c02e

Browse files
HolyLowRexXiong
authored andcommitted
[CELEBORN-2235][CIP-14] Adapt Java end's serialization to CppWriterClient
### What changes were proposed in this pull request? This PR adapts Java end's serialization to CppWriterClient, including RegisterShuffle/Response, Revive/Response, MapperEnd/Response. Joint test for cpp-write java-read procedure is included as well. ### Why are the changes needed? Support writing to Celeborn server with CppWriterClient. ### Does this PR resolve a correctness bug? No. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Compilation and integration tests. Closes #3561 from HolyLow/issue/celeborn-2235-adapt-java-to-cpp-writer-serialization. Authored-by: HolyLow <[email protected]> Signed-off-by: Shuang <[email protected]>
1 parent 2dd1b7a commit d30c02e

File tree

16 files changed

+517
-188
lines changed

16 files changed

+517
-188
lines changed

.github/workflows/cpp_integration.yml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,24 +85,31 @@ jobs:
8585
check-latest: false
8686
- name: Compile & Install Celeborn Java
8787
run: build/mvn clean install -DskipTests
88-
- name: Run Java-Cpp Hybrid Integration Test
88+
- name: Run Java-Write Cpp-Read Hybrid Integration Test (NONE Decompression)
8989
run: |
9090
build/mvn -pl worker \
9191
test-compile exec:java \
9292
-Dexec.classpathScope="test" \
9393
-Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.JavaWriteCppReadTestWithNONE" \
9494
-Dexec.args="-XX:MaxDirectMemorySize=2G"
95-
- name: Run Java-Cpp Hybrid Integration Test (LZ4 Decompression)
95+
- name: Run Java-Write Cpp-Read Hybrid Integration Test (LZ4 Decompression)
9696
run: |
9797
build/mvn -pl worker \
9898
test-compile exec:java \
9999
-Dexec.classpathScope="test" \
100100
-Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.JavaWriteCppReadTestWithLZ4" \
101101
-Dexec.args="-XX:MaxDirectMemorySize=2G"
102-
- name: Run Java-Cpp Hybrid Integration Test (ZSTD Decompression)
102+
- name: Run Java-Write Cpp-Read Hybrid Integration Test (ZSTD Decompression)
103103
run: |
104104
build/mvn -pl worker \
105105
test-compile exec:java \
106106
-Dexec.classpathScope="test" \
107107
-Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.JavaWriteCppReadTestWithZSTD" \
108108
-Dexec.args="-XX:MaxDirectMemorySize=2G"
109+
- name: Run Cpp-Write Java-Read Hybrid Integration Test (NONE Compression)
110+
run: |
111+
build/mvn -pl worker \
112+
test-compile exec:java \
113+
-Dexec.classpathScope="test" \
114+
-Dexec.mainClass="org.apache.celeborn.service.deploy.cluster.CppWriteJavaReadTestWithNONE" \
115+
-Dexec.args="-XX:MaxDirectMemorySize=2G"

client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import org.apache.celeborn.common.network.client.TransportClient;
4747
import org.apache.celeborn.common.network.client.TransportClientFactory;
4848
import org.apache.celeborn.common.network.protocol.PushData;
49+
import org.apache.celeborn.common.network.protocol.SerdeVersion;
4950
import org.apache.celeborn.common.network.protocol.TransportMessage;
5051
import org.apache.celeborn.common.network.util.TransportConf;
5152
import org.apache.celeborn.common.protocol.MessageType;
@@ -528,7 +529,7 @@ public Optional<PartitionLocation> regionStart(
528529
public Optional<PartitionLocation> revive(
529530
int shuffleId, int mapId, int attemptId, PartitionLocation location)
530531
throws CelebornIOException {
531-
Set<Integer> mapIds = new HashSet<>();
532+
List<Integer> mapIds = new ArrayList<>();
532533
mapIds.add(mapId);
533534
List<ReviveRequest> requests = new ArrayList<>();
534535
ReviveRequest req =
@@ -543,7 +544,7 @@ public Optional<PartitionLocation> revive(
543544
requests.add(req);
544545
PbChangeLocationResponse response =
545546
lifecycleManagerRef.askSync(
546-
ControlMessages.Revive$.MODULE$.apply(shuffleId, mapIds, requests),
547+
ControlMessages.Revive$.MODULE$.apply(shuffleId, mapIds, requests, SerdeVersion.V1),
547548
conf.clientRpcRequestPartitionLocationAskTimeout(),
548549
ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
549550
// per partitionKey only serve single PartitionLocation in Client Cache.

client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -550,11 +550,11 @@ private ConcurrentHashMap<Integer, PartitionLocation> registerShuffle(
550550
numPartitions,
551551
() ->
552552
lifecycleManagerRef.askSync(
553-
RegisterShuffle$.MODULE$.apply(shuffleId, numMappers, numPartitions),
553+
new RegisterShuffle(shuffleId, numMappers, numPartitions, SerdeVersion.V1),
554554
conf.clientRpcRegisterShuffleAskTimeout(),
555555
rpcMaxRetries,
556556
rpcRetryWait,
557-
ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
557+
ClassTag$.MODULE$.apply(RegisterShuffleResponse.class)));
558558
}
559559

560560
@Override
@@ -593,7 +593,7 @@ public PartitionLocation registerMapPartitionTask(
593593
partitionId,
594594
isSegmentGranularityVisible),
595595
conf.clientRpcRegisterShuffleAskTimeout(),
596-
ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
596+
ClassTag$.MODULE$.apply(RegisterShuffleResponse.class)));
597597

598598
return partitionLocationMap.get(partitionId);
599599
}
@@ -709,23 +709,18 @@ public boolean reportBarrierTaskFailure(int appShuffleId, String appShuffleIdent
709709
}
710710

711711
private ConcurrentHashMap<Integer, PartitionLocation> registerShuffleInternal(
712-
int shuffleId,
713-
int numMappers,
714-
int numPartitions,
715-
Callable<PbRegisterShuffleResponse> callable)
712+
int shuffleId, int numMappers, int numPartitions, Callable<RegisterShuffleResponse> callable)
716713
throws CelebornIOException {
717714
int numRetries = registerShuffleMaxRetries;
718715
StatusCode lastFailedStatusCode = null;
719716
while (numRetries > 0) {
720717
try {
721-
PbRegisterShuffleResponse response = callable.call();
722-
StatusCode respStatus = StatusCode.fromValue(response.getStatus());
718+
RegisterShuffleResponse response = callable.call();
719+
StatusCode respStatus = response.status();
723720
if (StatusCode.SUCCESS.equals(respStatus)) {
724721
ConcurrentHashMap<Integer, PartitionLocation> result = JavaUtils.newConcurrentHashMap();
725-
Tuple2<List<PartitionLocation>, List<PartitionLocation>> locations =
726-
PbSerDeUtils.fromPbPackedPartitionLocationsPair(
727-
response.getPackedPartitionLocationsPair());
728-
for (PartitionLocation location : locations._1) {
722+
PartitionLocation[] locations = response.partitionLocations();
723+
for (PartitionLocation location : locations) {
729724
pushExcludedWorkers.remove(location.hostAndPushPort());
730725
if (location.hasPeer()) {
731726
pushExcludedWorkers.remove(location.getPeer().hostAndPushPort());
@@ -900,43 +895,43 @@ Map<Integer, Integer> reviveBatch(
900895
oldLocMap.put(req.partitionId, req.loc);
901896
}
902897
try {
903-
PbChangeLocationResponse response =
898+
ChangeLocationResponse response =
904899
lifecycleManagerRef.askSync(
905-
Revive$.MODULE$.apply(shuffleId, mapIds, requests),
900+
Revive$.MODULE$.apply(
901+
shuffleId, new ArrayList<>(mapIds), new ArrayList<>(requests), SerdeVersion.V1),
906902
conf.clientRpcRequestPartitionLocationAskTimeout(),
907-
ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
903+
ClassTag$.MODULE$.apply(ChangeLocationResponse.class));
908904

909-
for (int i = 0; i < response.getEndedMapIdCount(); i++) {
910-
int mapId = response.getEndedMapId(i);
905+
for (int i = 0; i < response.endedMapIds().size(); i++) {
906+
int mapId = response.endedMapIds().get(i);
911907
mapperEndMap.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet()).add(mapId);
912908
}
913909

914-
for (int i = 0; i < response.getPartitionInfoCount(); i++) {
915-
PbChangeLocationPartitionInfo partitionInfo = response.getPartitionInfo(i);
916-
int partitionId = partitionInfo.getPartitionId();
917-
int statusCode = partitionInfo.getStatus();
918-
if (partitionInfo.getOldAvailable()) {
910+
for (Map.Entry<Integer, Tuple3<StatusCode, Boolean, PartitionLocation>> entry :
911+
response.newLocs().entrySet()) {
912+
int partitionId = entry.getKey();
913+
StatusCode statusCode = entry.getValue()._1();
914+
if (entry.getValue()._2() != null) {
919915
PartitionLocation oldLoc = oldLocMap.get(partitionId);
920916
// Currently, revive only check if main location available, here won't remove peer loc.
921917
pushExcludedWorkers.remove(oldLoc.hostAndPushPort());
922918
}
923919

924-
if (StatusCode.SUCCESS.getValue() == statusCode) {
925-
PartitionLocation loc =
926-
PbSerDeUtils.fromPbPartitionLocation(partitionInfo.getPartition());
920+
if (StatusCode.SUCCESS == statusCode) {
921+
PartitionLocation loc = entry.getValue()._3();
927922
partitionLocationMap.put(partitionId, loc);
928923
pushExcludedWorkers.remove(loc.hostAndPushPort());
929924
if (loc.hasPeer()) {
930925
pushExcludedWorkers.remove(loc.getPeer().hostAndPushPort());
931926
}
932-
} else if (StatusCode.STAGE_ENDED.getValue() == statusCode) {
927+
} else if (StatusCode.STAGE_ENDED == statusCode) {
933928
stageEndShuffleSet.add(shuffleId);
934929
return results;
935-
} else if (StatusCode.SHUFFLE_UNREGISTERED.getValue() == statusCode) {
930+
} else if (StatusCode.SHUFFLE_UNREGISTERED == statusCode) {
936931
logger.error("SHUFFLE_NOT_REGISTERED!");
937932
return null;
938933
}
939-
results.put(partitionId, statusCode);
934+
results.put(partitionId, (int) (statusCode.getValue()));
940935
}
941936

942937
return results;
@@ -1806,7 +1801,8 @@ private void mapEndInternal(
18061801
pushState.getFailedBatches(),
18071802
numPartitions,
18081803
crc32PerPartition,
1809-
bytesPerPartition),
1804+
bytesPerPartition,
1805+
SerdeVersion.V1),
18101806
rpcMaxRetries,
18111807
rpcRetryWait,
18121808
ClassTag$.MODULE$.apply(MapperEndResponse.class));

0 commit comments

Comments
 (0)