Skip to content

Commit 77e64f2

Browse files
committed
Fix a race condition in aggregate ProfileFileSupplier
1 parent c044542 commit 77e64f2

File tree

3 files changed

+85
-28
lines changed

3 files changed

+85
-28
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"type": "bugfix",
3+
"category": "AWS SDK for Java v2",
4+
"contributor": "",
5+
"description": "Fix a race condition in aggregate ProfileFileSupplier that could cause credential resolution failures with shared DefaultCredentialsProvider."
6+
}

core/profiles/src/main/java/software/amazon/awssdk/profiles/ProfileFileSupplier.java

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.function.Supplier;
2727
import software.amazon.awssdk.annotations.SdkPublicApi;
2828
import software.amazon.awssdk.profiles.internal.ProfileFileRefresher;
29+
import software.amazon.awssdk.utils.Pair;
2930

3031
/**
3132
* Encapsulates the logic for supplying either a single or multiple ProfileFile instances.
@@ -127,43 +128,41 @@ static ProfileFileSupplier aggregate(ProfileFileSupplier... suppliers) {
127128

128129
return new ProfileFileSupplier() {
129130

130-
final AtomicReference<ProfileFile> currentAggregateProfileFile = new AtomicReference<>();
131-
final Map<Supplier<ProfileFile>, ProfileFile> currentValuesBySupplier
132-
= Collections.synchronizedMap(new LinkedHashMap<>());
131+
final AtomicReference<Pair<Map<Supplier<ProfileFile>, ProfileFile>, ProfileFile>> state =
132+
new AtomicReference<>(Pair.of(Collections.emptyMap(), ProfileFile.empty()));
133133

134134
@Override
135135
public ProfileFile get() {
136-
boolean refreshAggregate = false;
137-
for (ProfileFileSupplier supplier : suppliers) {
138-
if (didSuppliedValueChange(supplier)) {
139-
refreshAggregate = true;
136+
while(true) {
137+
Pair<Map<Supplier<ProfileFile>, ProfileFile>, ProfileFile> currentState = state.get();
138+
Map<Supplier<ProfileFile>, ProfileFile> nextValues = new LinkedHashMap<>(currentState.left());
139+
140+
boolean refreshAggregate = false;
141+
142+
for (ProfileFileSupplier supplier : suppliers) {
143+
ProfileFile next = supplier.get();
144+
ProfileFile prev = nextValues.put(supplier, next);
145+
// we ONLY care about if the reference has changed, we don't care about object equality here
146+
if (prev != next) {
147+
refreshAggregate = true;
148+
}
140149
}
141-
}
142-
143-
if (refreshAggregate) {
144-
refreshCurrentAggregate();
145-
}
146-
147-
return currentAggregateProfileFile.get();
148-
}
149150

150-
private boolean didSuppliedValueChange(Supplier<ProfileFile> supplier) {
151-
ProfileFile next = supplier.get();
152-
ProfileFile current = currentValuesBySupplier.put(supplier, next);
151+
if (!refreshAggregate) {
152+
return currentState.right();
153+
}
153154

154-
return !Objects.equals(next, current);
155-
}
155+
ProfileFile.Aggregator aggregator = ProfileFile.aggregator();
156+
nextValues.values().forEach(aggregator::addFile);
157+
ProfileFile nextAggregate = aggregator.build();
156158

157-
private void refreshCurrentAggregate() {
158-
ProfileFile.Aggregator aggregator = ProfileFile.aggregator();
159-
currentValuesBySupplier.values().forEach(aggregator::addFile);
160-
ProfileFile current = currentAggregateProfileFile.get();
161-
ProfileFile next = aggregator.build();
162-
if (!Objects.equals(current, next)) {
163-
currentAggregateProfileFile.compareAndSet(current, next);
159+
Pair<Map<Supplier<ProfileFile>, ProfileFile>, ProfileFile> nextState = Pair.of(nextValues, nextAggregate);
160+
if (state.compareAndSet(currentState, nextState)) {
161+
return nextAggregate;
162+
}
163+
// else: another thread has modified the state in between, retry with the fresh state
164164
}
165165
}
166-
167166
};
168167
}
169168

core/profiles/src/test/java/software/amazon/awssdk/profiles/ProfileFileSupplierTest.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,20 @@
3131
import java.time.ZoneId;
3232
import java.time.ZoneOffset;
3333
import java.time.temporal.TemporalAmount;
34+
import java.util.ArrayList;
3435
import java.util.Arrays;
3536
import java.util.List;
3637
import java.util.Objects;
3738
import java.util.Optional;
3839
import java.util.Set;
40+
import java.util.concurrent.CompletableFuture;
3941
import java.util.concurrent.ConcurrentHashMap;
42+
import java.util.concurrent.CountDownLatch;
43+
import java.util.concurrent.ExecutionException;
44+
import java.util.concurrent.ExecutorService;
45+
import java.util.concurrent.Executors;
46+
import java.util.concurrent.Future;
47+
import java.util.concurrent.TimeUnit;
4048
import java.util.concurrent.atomic.AtomicInteger;
4149
import java.util.function.Predicate;
4250
import java.util.stream.Collectors;
@@ -503,6 +511,50 @@ void aggregate_duplicateOptionsGivenReloadingProfileFirst_preservesPrecedence()
503511
assertThat(accessKeyId).isEqualTo("defaultAccessKey2");
504512
}
505513

514+
@Test
515+
void aggregate_concurrentGetAlwaysReturnsCorrectAggregate() throws ExecutionException, InterruptedException {
516+
ProfileFile credentialFile = credentialProfileFile("test1", "key1", "secret1");
517+
ProfileFile configFile = configProfileFile("profile test",
518+
Pair.of("region", "us-west-2"),
519+
Pair.of("aws_account_id", "012354678922"));
520+
521+
522+
ProfileFile expectedAggregate = ProfileFile.aggregator().addFile(credentialFile).addFile(configFile).build();
523+
524+
ProfileFileSupplier supplier = ProfileFileSupplier.aggregate(() -> credentialFile, () -> configFile);
525+
526+
ExecutorService executor = Executors.newFixedThreadPool(24);
527+
CountDownLatch startLatch = new CountDownLatch(1);
528+
List<Future<Boolean>> tasks = new ArrayList<>();
529+
530+
for(int i = 0; i < 24; i++) {
531+
tasks.add(executor.submit(() -> {
532+
try {
533+
startLatch.await();
534+
ProfileFile resolved = supplier.get();
535+
return Objects.equals(expectedAggregate, resolved);
536+
} catch (InterruptedException e) {
537+
throw new RuntimeException(e);
538+
}
539+
}));
540+
}
541+
// All tasks are now submitted — release them
542+
startLatch.countDown();
543+
executor.shutdown();
544+
try {
545+
assertThat(executor.awaitTermination(10, TimeUnit.SECONDS))
546+
.as("executor did not terminate")
547+
.isTrue();
548+
} finally {
549+
executor.shutdownNow();
550+
}
551+
552+
// assert that all concurrent get's returned the same, expected aggregate
553+
for(Future<Boolean> task : tasks) {
554+
assertThat(task.get()).isTrue();
555+
}
556+
}
557+
506558
@Test
507559
void fixedProfileFile_nullProfileFile_returnsNonNullSupplier() {
508560
ProfileFile file = null;

0 commit comments

Comments
 (0)