diff --git a/.changes/next-release/bugfix-AWSSDKforJavav2-0caa9ea.json b/.changes/next-release/bugfix-AWSSDKforJavav2-0caa9ea.json new file mode 100644 index 00000000000..90582c7dd34 --- /dev/null +++ b/.changes/next-release/bugfix-AWSSDKforJavav2-0caa9ea.json @@ -0,0 +1,6 @@ +{ + "type": "bugfix", + "category": "AWS SDK for Java v2", + "contributor": "", + "description": "Fix a race condition in aggregate ProfileFileSupplier that could cause credential resolution failures with shared DefaultCredentialsProvider." +} diff --git a/core/profiles/src/main/java/software/amazon/awssdk/profiles/ProfileFileSupplier.java b/core/profiles/src/main/java/software/amazon/awssdk/profiles/ProfileFileSupplier.java index 4ea4cd34989..e61935a7f7a 100644 --- a/core/profiles/src/main/java/software/amazon/awssdk/profiles/ProfileFileSupplier.java +++ b/core/profiles/src/main/java/software/amazon/awssdk/profiles/ProfileFileSupplier.java @@ -17,14 +17,10 @@ import java.nio.file.Files; import java.nio.file.Path; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.Objects; import java.util.Optional; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import software.amazon.awssdk.annotations.SdkPublicApi; +import software.amazon.awssdk.profiles.internal.AggregateProfileFileSupplier; import software.amazon.awssdk.profiles.internal.ProfileFileRefresher; /** @@ -125,46 +121,7 @@ static ProfileFileSupplier fixedProfileFile(ProfileFile profileFile) { */ static ProfileFileSupplier aggregate(ProfileFileSupplier... suppliers) { - return new ProfileFileSupplier() { - - final AtomicReference currentAggregateProfileFile = new AtomicReference<>(); - final Map, ProfileFile> currentValuesBySupplier - = Collections.synchronizedMap(new LinkedHashMap<>()); - - @Override - public ProfileFile get() { - boolean refreshAggregate = false; - for (ProfileFileSupplier supplier : suppliers) { - if (didSuppliedValueChange(supplier)) { - refreshAggregate = true; - } - } - - if (refreshAggregate) { - refreshCurrentAggregate(); - } - - return currentAggregateProfileFile.get(); - } - - private boolean didSuppliedValueChange(Supplier supplier) { - ProfileFile next = supplier.get(); - ProfileFile current = currentValuesBySupplier.put(supplier, next); - - return !Objects.equals(next, current); - } - - private void refreshCurrentAggregate() { - ProfileFile.Aggregator aggregator = ProfileFile.aggregator(); - currentValuesBySupplier.values().forEach(aggregator::addFile); - ProfileFile current = currentAggregateProfileFile.get(); - ProfileFile next = aggregator.build(); - if (!Objects.equals(current, next)) { - currentAggregateProfileFile.compareAndSet(current, next); - } - } - - }; + return new AggregateProfileFileSupplier(suppliers); } } diff --git a/core/profiles/src/main/java/software/amazon/awssdk/profiles/internal/AggregateProfileFileSupplier.java b/core/profiles/src/main/java/software/amazon/awssdk/profiles/internal/AggregateProfileFileSupplier.java new file mode 100644 index 00000000000..b371ac35caf --- /dev/null +++ b/core/profiles/src/main/java/software/amazon/awssdk/profiles/internal/AggregateProfileFileSupplier.java @@ -0,0 +1,104 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.profiles.internal; + +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.profiles.ProfileFile; +import software.amazon.awssdk.profiles.ProfileFileSupplier; + +/** + * A {@link ProfileFileSupplier} that combines the {@link ProfileFile} objects from multiple + * {@code ProfileFileSupplier}s. Objects are passed into {@link ProfileFile.Aggregator}. + */ +@SdkInternalApi +public class AggregateProfileFileSupplier implements ProfileFileSupplier { + private final List suppliers; + + // supplier values and the resulting aggregate must always be updated atomically together + private final AtomicReference state = + new AtomicReference<>(new SupplierState(Collections.emptyMap(), null)); + + public AggregateProfileFileSupplier(ProfileFileSupplier... suppliers) { + this.suppliers = Collections.unmodifiableList(Arrays.asList(suppliers)); + } + + @Override + public ProfileFile get() { + SupplierState currentState = state.get(); + Map, ProfileFile> currentValues = currentState.values; + Map, ProfileFile> changedValues = changedSupplierValues(currentValues); + + if (changedValues == null) { + // no suppliers have changed values, return the current aggregate + return currentState.aggregate; + } + + // one or more supplier values have changed, we need to update the aggregate (and the state) + // the order of the suppliers matters so we MUST preserve it using LinkedHashMap with insertion ordering + Map, ProfileFile> nextValues = new LinkedHashMap<>(currentValues); + nextValues.putAll(changedValues); + + ProfileFile.Aggregator aggregator = ProfileFile.aggregator(); + nextValues.values().forEach(aggregator::addFile); + ProfileFile nextAggregate = aggregator.build(); + + SupplierState nextState = new SupplierState(nextValues, nextAggregate); + if (state.compareAndSet(currentState, nextState)) { + return nextAggregate; + } + // else: another thread has modified the state in between, assume it is up to date and use the new state + return state.get().aggregate; + } + + // return the suppliers with changed values. Returns null if no values have changed + private Map, ProfileFile> changedSupplierValues(Map, ProfileFile> currentValues) { + Map, ProfileFile> changedValues = null; + for (ProfileFileSupplier supplier : suppliers) { + ProfileFile next = supplier.get(); + ProfileFile prev = currentValues.get(supplier); + // we ONLY care about if the reference has changed, we don't care about object equality here + if (prev != next) { + if (changedValues == null) { + // changed values must also preserve supplier order + changedValues = new LinkedHashMap<>(); + } + changedValues.put(supplier, next); + } + } + return changedValues; + } + + /** + * Supplier values and the resulting aggregate must always be updated atomically together. + * This record class tracks all mutable elements of the supplier's state together. + */ + private static final class SupplierState { + private final Map, ProfileFile> values; + private final ProfileFile aggregate; + + private SupplierState(Map, ProfileFile> values, ProfileFile aggregate) { + this.values = values; + this.aggregate = aggregate; + } + } +} diff --git a/core/profiles/src/test/java/software/amazon/awssdk/profiles/ProfileFileSupplierTest.java b/core/profiles/src/test/java/software/amazon/awssdk/profiles/ProfileFileSupplierTest.java index 5a725914aba..4f05ae26172 100644 --- a/core/profiles/src/test/java/software/amazon/awssdk/profiles/ProfileFileSupplierTest.java +++ b/core/profiles/src/test/java/software/amazon/awssdk/profiles/ProfileFileSupplierTest.java @@ -31,12 +31,20 @@ import java.time.ZoneId; import java.time.ZoneOffset; import java.time.temporal.TemporalAmount; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -503,6 +511,50 @@ void aggregate_duplicateOptionsGivenReloadingProfileFirst_preservesPrecedence() assertThat(accessKeyId).isEqualTo("defaultAccessKey2"); } + @Test + void aggregate_concurrentGetAlwaysReturnsCorrectAggregate() throws ExecutionException, InterruptedException { + ProfileFile credentialFile = credentialProfileFile("test1", "key1", "secret1"); + ProfileFile configFile = configProfileFile("profile test", + Pair.of("region", "us-west-2"), + Pair.of("aws_account_id", "012354678922")); + + + ProfileFile expectedAggregate = ProfileFile.aggregator().addFile(credentialFile).addFile(configFile).build(); + + ProfileFileSupplier supplier = ProfileFileSupplier.aggregate(() -> credentialFile, () -> configFile); + + ExecutorService executor = Executors.newFixedThreadPool(24); + CountDownLatch startLatch = new CountDownLatch(1); + List> tasks = new ArrayList<>(); + + for(int i = 0; i < 24; i++) { + tasks.add(executor.submit(() -> { + try { + startLatch.await(); + ProfileFile resolved = supplier.get(); + return Objects.equals(expectedAggregate, resolved); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + })); + } + // All tasks are now submitted — release them + startLatch.countDown(); + executor.shutdown(); + try { + assertThat(executor.awaitTermination(10, TimeUnit.SECONDS)) + .as("executor did not terminate") + .isTrue(); + } finally { + executor.shutdownNow(); + } + + // assert that all concurrent get's returned the same, expected aggregate + for(Future task : tasks) { + assertThat(task.get()).isTrue(); + } + } + @Test void fixedProfileFile_nullProfileFile_returnsNonNullSupplier() { ProfileFile file = null;