Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AWSSDKforJavav2-0caa9ea.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Fix a race condition in aggregate ProfileFileSupplier that could cause credential resolution failures with shared DefaultCredentialsProvider."
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,10 @@

import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import software.amazon.awssdk.annotations.SdkPublicApi;
import software.amazon.awssdk.profiles.internal.AggregateProfileFileSupplier;
import software.amazon.awssdk.profiles.internal.ProfileFileRefresher;

/**
Expand Down Expand Up @@ -125,46 +121,7 @@ static ProfileFileSupplier fixedProfileFile(ProfileFile profileFile) {
*/
static ProfileFileSupplier aggregate(ProfileFileSupplier... suppliers) {

return new ProfileFileSupplier() {

final AtomicReference<ProfileFile> currentAggregateProfileFile = new AtomicReference<>();
final Map<Supplier<ProfileFile>, ProfileFile> currentValuesBySupplier
= Collections.synchronizedMap(new LinkedHashMap<>());

@Override
public ProfileFile get() {
boolean refreshAggregate = false;
for (ProfileFileSupplier supplier : suppliers) {
if (didSuppliedValueChange(supplier)) {
refreshAggregate = true;
}
}

if (refreshAggregate) {
refreshCurrentAggregate();
}

return currentAggregateProfileFile.get();
}

private boolean didSuppliedValueChange(Supplier<ProfileFile> supplier) {
ProfileFile next = supplier.get();
ProfileFile current = currentValuesBySupplier.put(supplier, next);

return !Objects.equals(next, current);
}

private void refreshCurrentAggregate() {
ProfileFile.Aggregator aggregator = ProfileFile.aggregator();
currentValuesBySupplier.values().forEach(aggregator::addFile);
ProfileFile current = currentAggregateProfileFile.get();
ProfileFile next = aggregator.build();
if (!Objects.equals(current, next)) {
currentAggregateProfileFile.compareAndSet(current, next);
}
}

};
return new AggregateProfileFileSupplier(suppliers);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.profiles.internal;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.profiles.ProfileFile;
import software.amazon.awssdk.profiles.ProfileFileSupplier;

/**
* A {@link ProfileFileSupplier} that combines the {@link ProfileFile} objects from multiple
* {@code ProfileFileSupplier}s. Objects are passed into {@link ProfileFile.Aggregator}.
*/
@SdkInternalApi
public class AggregateProfileFileSupplier implements ProfileFileSupplier {
private final List<ProfileFileSupplier> suppliers;

// supplier values and the resulting aggregate must always be updated atomically together
private final AtomicReference<SupplierState> state =
new AtomicReference<>(new SupplierState(Collections.emptyMap(), null));

public AggregateProfileFileSupplier(ProfileFileSupplier... suppliers) {
this.suppliers = Collections.unmodifiableList(Arrays.asList(suppliers));
}

@Override
public ProfileFile get() {
SupplierState currentState = state.get();
Map<Supplier<ProfileFile>, ProfileFile> currentValues = currentState.values;
Map<Supplier<ProfileFile>, ProfileFile> changedValues = changedSupplierValues(currentValues);

if (changedValues == null) {
// no suppliers have changed values, return the current aggregate
return currentState.aggregate;
}

// one or more supplier values have changed, we need to update the aggregate (and the state)
// the order of the suppliers matters so we MUST preserve it using LinkedHashMap with insertion ordering
Map<Supplier<ProfileFile>, ProfileFile> nextValues = new LinkedHashMap<>(currentValues);
nextValues.putAll(changedValues);

ProfileFile.Aggregator aggregator = ProfileFile.aggregator();
nextValues.values().forEach(aggregator::addFile);
ProfileFile nextAggregate = aggregator.build();

SupplierState nextState = new SupplierState(nextValues, nextAggregate);
if (state.compareAndSet(currentState, nextState)) {
return nextAggregate;
}
// else: another thread has modified the state in between, assume it is up to date and use the new state
return state.get().aggregate;
}

// return the suppliers with changed values. Returns null if no values have changed
private Map<Supplier<ProfileFile>, ProfileFile> changedSupplierValues(Map<Supplier<ProfileFile>, ProfileFile> currentValues) {
Map<Supplier<ProfileFile>, ProfileFile> changedValues = null;
for (ProfileFileSupplier supplier : suppliers) {
ProfileFile next = supplier.get();
ProfileFile prev = currentValues.get(supplier);
// we ONLY care about if the reference has changed, we don't care about object equality here
if (prev != next) {
if (changedValues == null) {
// changed values must also preserve supplier order
changedValues = new LinkedHashMap<>();
}
changedValues.put(supplier, next);
}
}
return changedValues;
}

/**
* Supplier values and the resulting aggregate must always be updated atomically together.
* This record class tracks all mutable elements of the supplier's state together.
*/
private static final class SupplierState {
private final Map<Supplier<ProfileFile>, ProfileFile> values;
private final ProfileFile aggregate;

private SupplierState(Map<Supplier<ProfileFile>, ProfileFile> values, ProfileFile aggregate) {
this.values = values;
this.aggregate = aggregate;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,20 @@
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.time.temporal.TemporalAmount;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -503,6 +511,50 @@ void aggregate_duplicateOptionsGivenReloadingProfileFirst_preservesPrecedence()
assertThat(accessKeyId).isEqualTo("defaultAccessKey2");
}

@Test
void aggregate_concurrentGetAlwaysReturnsCorrectAggregate() throws ExecutionException, InterruptedException {
ProfileFile credentialFile = credentialProfileFile("test1", "key1", "secret1");
ProfileFile configFile = configProfileFile("profile test",
Pair.of("region", "us-west-2"),
Pair.of("aws_account_id", "012354678922"));


ProfileFile expectedAggregate = ProfileFile.aggregator().addFile(credentialFile).addFile(configFile).build();

ProfileFileSupplier supplier = ProfileFileSupplier.aggregate(() -> credentialFile, () -> configFile);

ExecutorService executor = Executors.newFixedThreadPool(24);
CountDownLatch startLatch = new CountDownLatch(1);
List<Future<Boolean>> tasks = new ArrayList<>();

for(int i = 0; i < 24; i++) {
tasks.add(executor.submit(() -> {
try {
startLatch.await();
ProfileFile resolved = supplier.get();
return Objects.equals(expectedAggregate, resolved);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}));
}
// All tasks are now submitted — release them
startLatch.countDown();
executor.shutdown();
try {
assertThat(executor.awaitTermination(10, TimeUnit.SECONDS))
.as("executor did not terminate")
.isTrue();
} finally {
executor.shutdownNow();
}

// assert that all concurrent get's returned the same, expected aggregate
for(Future<Boolean> task : tasks) {
assertThat(task.get()).isTrue();
}
}

@Test
void fixedProfileFile_nullProfileFile_returnsNonNullSupplier() {
ProfileFile file = null;
Expand Down
Loading