Skip to content

Commit 0515db4

Browse files
Frank-Gu-81QuChen88
authored andcommitted
Caching - Add IAM authentication support for ElastiCache Valkey (PR #4)
- Add ElastiCacheIamTokenUtility for token generation - Extend DataRemoteCachePlugin with IAM auth detection - Support serverless and regular ElastiCache endpoints - Implement 15-minute token refresh and 12-hour re-auth cycles - Add cacheIamAuthEnabled configuration property
1 parent 76483ca commit 0515db4

File tree

7 files changed

+980
-19
lines changed

7 files changed

+980
-19
lines changed

examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ public class DatabaseConnectionWithCacheExample {
1212
private static final String DB_CONNECTION_STRING = env.get("DB_CONNECTION_STRING");
1313
private static final String CACHE_RW_SERVER_ADDR = env.get("CACHE_RW_SERVER_ADDR");
1414
private static final String CACHE_RO_SERVER_ADDR = env.get("CACHE_RO_SERVER_ADDR");
15+
// If the cache server is authenticated with IAM
16+
private static final String CACHE_NAME = env.get("CACHE_NAME");
17+
// Both IAM and traditional auth uses the same CACHE_USERNAME
18+
private static final String CACHE_USERNAME = env.get("CACHE_USERNAME"); // e.g., "iam-user-01" / "username"
19+
private static final String CACHE_IAM_REGION = env.get("CACHE_IAM_REGION"); // e.g., "us-west-2"
20+
// If the cache server is authenticated with traditional username password
21+
// private static final String CACHE_PASSWORD = env.get("CACHE_PASSWORD");
1522
private static final String USERNAME = env.get("DB_USERNAME");
1623
private static final String PASSWORD = env.get("DB_PASSWORD");
1724
private static final String USE_SSL = env.get("USE_SSL");
@@ -30,6 +37,12 @@ public static void main(String[] args) throws SQLException {
3037
properties.setProperty("wrapperPlugins", "dataRemoteCache");
3138
properties.setProperty("cacheEndpointAddrRw", CACHE_RW_SERVER_ADDR);
3239
properties.setProperty("cacheEndpointAddrRo", CACHE_RO_SERVER_ADDR);
40+
// If the cache server is authenticated with IAM
41+
properties.setProperty("cacheName", CACHE_NAME);
42+
properties.setProperty("cacheUsername", CACHE_USERNAME);
43+
properties.setProperty("cacheIamRegion", CACHE_IAM_REGION);
44+
// If the cache server is authenticated with traditional username password
45+
// properties.setProperty("cachePassword", PASSWORD);
3346
properties.setProperty("cacheUseSSL", USE_SSL); // "true" or "false"
3447
properties.setProperty("wrapperLogUnclosedConnections", "true");
3548
String queryStr = "/*+ CACHE_PARAM(ttl=300s) */ select * from cinemas";

wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java

Lines changed: 150 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,52 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
117
package software.amazon.jdbc.plugin.cache;
218

319
import io.lettuce.core.RedisClient;
4-
import io.lettuce.core.RedisCommandExecutionException;
20+
import io.lettuce.core.RedisCredentials;
21+
import io.lettuce.core.RedisCredentialsProvider;
522
import io.lettuce.core.RedisURI;
23+
import io.lettuce.core.RedisCommandExecutionException;
624
import io.lettuce.core.SetArgs;
725
import io.lettuce.core.api.StatefulRedisConnection;
826
import io.lettuce.core.api.async.RedisAsyncCommands;
927
import io.lettuce.core.codec.ByteArrayCodec;
1028
import io.lettuce.core.resource.ClientResources;
11-
import software.amazon.jdbc.AwsWrapperProperty;
29+
import reactor.core.publisher.Mono;
30+
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
31+
import software.amazon.awssdk.regions.Region;
1232
import java.nio.charset.StandardCharsets;
1333
import java.security.MessageDigest;
1434
import java.security.NoSuchAlgorithmException;
1535
import java.time.Duration;
1636
import java.util.Properties;
37+
import java.util.concurrent.TimeUnit;
1738
import java.util.concurrent.locks.ReentrantLock;
39+
import java.util.function.Supplier;
1840
import java.util.logging.Logger;
1941
import org.apache.commons.pool2.BasePooledObjectFactory;
2042
import org.apache.commons.pool2.impl.GenericObjectPool;
2143
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
2244
import org.apache.commons.pool2.impl.DefaultPooledObject;
2345
import org.apache.commons.pool2.PooledObject;
46+
import software.amazon.jdbc.AwsWrapperProperty;
2447
import software.amazon.jdbc.PropertyDefinition;
48+
import software.amazon.jdbc.authentication.AwsCredentialsManager;
49+
import software.amazon.jdbc.plugin.iam.ElastiCacheIamTokenUtility;
2550
import software.amazon.jdbc.util.StringUtils;
2651

2752
// Abstraction layer on top of a connection to a remote cache server
@@ -31,18 +56,21 @@ public class CacheConnection {
3156
private static volatile GenericObjectPool<StatefulRedisConnection<byte[], byte[]>> readConnectionPool;
3257
private static volatile GenericObjectPool<StatefulRedisConnection<byte[], byte[]>> writeConnectionPool;
3358
private static final GenericObjectPoolConfig<StatefulRedisConnection<byte[], byte[]>> poolConfig = createPoolConfig();
34-
private final String cacheRwServerAddr; // read-write cache server
35-
private final String cacheRoServerAddr; // read-only cache server
36-
private MessageDigest msgHashDigest = null;
3759

3860
private static final int DEFAULT_POOL_SIZE = 20;
3961
private static final int DEFAULT_POOL_MAX_IDLE = 20;
4062
private static final int DEFAULT_POOL_MIN_IDLE = 0;
4163
private static final long DEFAULT_MAX_BORROW_WAIT_MS = 100;
64+
private static final long TOKEN_CACHE_DURATION = 15 * 60 - 30;
4265

4366
private static final ReentrantLock READ_LOCK = new ReentrantLock();
4467
private static final ReentrantLock WRITE_LOCK = new ReentrantLock();
4568

69+
private final String cacheRwServerAddr; // read-write cache server
70+
private final String cacheRoServerAddr; // read-only cache server
71+
private final String[] defaultCacheServerHostAndPort;
72+
private MessageDigest msgHashDigest = null;
73+
4674
protected static final AwsWrapperProperty CACHE_RW_ENDPOINT_ADDR =
4775
new AwsWrapperProperty(
4876
"cacheEndpointAddrRw",
@@ -61,7 +89,38 @@ public class CacheConnection {
6189
"true",
6290
"Whether to use SSL for cache connections.");
6391

92+
protected static final AwsWrapperProperty CACHE_IAM_REGION =
93+
new AwsWrapperProperty(
94+
"cacheIamRegion",
95+
null,
96+
"AWS region for ElastiCache IAM authentication.");
97+
98+
protected static final AwsWrapperProperty CACHE_USERNAME =
99+
new AwsWrapperProperty(
100+
"cacheUsername",
101+
null,
102+
"Username for ElastiCache regular authentication.");
103+
104+
protected static final AwsWrapperProperty CACHE_PASSWORD =
105+
new AwsWrapperProperty(
106+
"cachePassword",
107+
null,
108+
"Password for ElastiCache regular authentication.");
109+
110+
protected static final AwsWrapperProperty CACHE_NAME =
111+
new AwsWrapperProperty(
112+
"cacheName",
113+
null,
114+
"Explicit cache name for ElastiCache IAM authentication. ");
115+
64116
private final boolean useSSL;
117+
private final boolean iamAuthEnabled;
118+
private final String cacheIamRegion;
119+
private final String cacheUsername;
120+
private final String cacheName;
121+
private final String cachePassword;
122+
private final Properties awsProfileProperties;
123+
private final AwsCredentialsProvider credentialsProvider;
65124

66125
static {
67126
PropertyDefinition.registerPluginProperties(CacheConnection.class);
@@ -71,6 +130,44 @@ public CacheConnection(final Properties properties) {
71130
this.cacheRwServerAddr = CACHE_RW_ENDPOINT_ADDR.getString(properties);
72131
this.cacheRoServerAddr = CACHE_RO_ENDPOINT_ADDR.getString(properties);
73132
this.useSSL = Boolean.parseBoolean(CACHE_USE_SSL.getString(properties));
133+
this.cacheName = CACHE_NAME.getString(properties);
134+
this.cacheIamRegion = CACHE_IAM_REGION.getString(properties);
135+
this.cacheUsername = CACHE_USERNAME.getString(properties);
136+
this.cachePassword = CACHE_PASSWORD.getString(properties);
137+
this.iamAuthEnabled = !StringUtils.isNullOrEmpty(this.cacheIamRegion);
138+
boolean hasTraditionalAuth = !StringUtils.isNullOrEmpty(this.cachePassword);
139+
// Validate authentication configuration
140+
if (this.iamAuthEnabled && hasTraditionalAuth) {
141+
throw new IllegalArgumentException(
142+
"Cannot specify both IAM authentication (cacheIamRegion) and traditional authentication (cachePassword). Choose one authentication method.");
143+
}
144+
if (this.cacheRwServerAddr == null) {
145+
throw new IllegalArgumentException("Cache endpoint address is required");
146+
}
147+
this.defaultCacheServerHostAndPort = getHostnameAndPort(this.cacheRwServerAddr);
148+
if (this.iamAuthEnabled) {
149+
if (this.cacheUsername == null || this.defaultCacheServerHostAndPort[0] == null || this.cacheName == null) {
150+
throw new IllegalArgumentException("IAM authentication requires cache name, username, region, and hostname");
151+
}
152+
}
153+
if (PropertyDefinition.AWS_PROFILE.getString(properties) != null) {
154+
this.awsProfileProperties = new Properties();
155+
this.awsProfileProperties.setProperty(
156+
PropertyDefinition.AWS_PROFILE.name,
157+
PropertyDefinition.AWS_PROFILE.getString(properties)
158+
);
159+
} else {
160+
this.awsProfileProperties = null;
161+
}
162+
if (this.iamAuthEnabled) {
163+
// Handle null case
164+
Properties propsToPass = (this.awsProfileProperties != null)
165+
? this.awsProfileProperties
166+
: new Properties();
167+
this.credentialsProvider = AwsCredentialsManager.getProvider(null, propsToPass);
168+
} else {
169+
this.credentialsProvider = null;
170+
}
74171
}
75172

76173
/* Here we check if we need to initialise connection pool for read or write to cache.
@@ -113,15 +210,14 @@ private void createConnectionPool(boolean isRead) {
113210
if (isRead && !StringUtils.isNullOrEmpty(cacheRoServerAddr)) {
114211
serverAddr = cacheRoServerAddr;
115212
}
116-
String[] hostnameAndPort = serverAddr.split(":");
117-
RedisURI redisUriCluster = RedisURI.Builder.redis(hostnameAndPort[0])
118-
.withPort(Integer.parseInt(hostnameAndPort[1]))
119-
.withSsl(useSSL).withVerifyPeer(false).withLibraryName("aws-sql-jdbc-lettuce").build();
213+
String[] hostnameAndPort = getHostnameAndPort(serverAddr);
214+
RedisURI redisUriCluster = buildRedisURI(hostnameAndPort[0], Integer.parseInt(hostnameAndPort[1]));
120215

121216
RedisClient client = RedisClient.create(resources, redisUriCluster);
122217
GenericObjectPool<StatefulRedisConnection<byte[], byte[]>> pool = new GenericObjectPool<>(
123218
new BasePooledObjectFactory<StatefulRedisConnection<byte[], byte[]>>() {
124219
public StatefulRedisConnection<byte[], byte[]> create() {
220+
125221
StatefulRedisConnection<byte[], byte[]> connection = client.connect(new ByteArrayCodec());
126222
// In cluster mode, we need to send READONLY command to the server for reading from replica.
127223
// Note: we gracefully ignore ERR reply to support non cluster mode.
@@ -148,7 +244,6 @@ public PooledObject<StatefulRedisConnection<byte[], byte[]>> wrap(StatefulRedisC
148244
} else {
149245
writeConnectionPool = pool;
150246
}
151-
152247
} catch (Exception e) {
153248
String poolType = isRead ? "read" : "write";
154249
String errorMsg = String.format("Failed to create Cache %s connection pool", poolType);
@@ -247,13 +342,13 @@ public void writeToCache(String key, byte[] value, int expiry) {
247342
private void returnConnectionBackToPool(StatefulRedisConnection <byte[], byte[]> connection, boolean isConnectionBroken, boolean isRead) {
248343
GenericObjectPool<StatefulRedisConnection<byte[], byte[]>> pool = isRead ? readConnectionPool : writeConnectionPool;
249344
if (isConnectionBroken) {
250-
try {
251-
pool.invalidateObject(connection);
252-
} catch (Exception e) {
253-
throw new RuntimeException("Could not invalidate connection for the pool", e);
254-
}
345+
try {
346+
pool.invalidateObject(connection);
347+
} catch (Exception e) {
348+
throw new RuntimeException("Could not invalidate connection for the pool", e);
349+
}
255350
} else {
256-
pool.returnObject(connection);
351+
pool.returnObject(connection);
257352
}
258353
}
259354

@@ -263,4 +358,43 @@ protected void setConnectionPools(GenericObjectPool<StatefulRedisConnection<byte
263358
readConnectionPool = readPool;
264359
writeConnectionPool = writePool;
265360
}
361+
362+
protected RedisURI buildRedisURI(String hostname, int port) {
363+
RedisURI.Builder uriBuilder = RedisURI.Builder.redis(hostname)
364+
.withPort(port)
365+
.withSsl(useSSL)
366+
.withVerifyPeer(false)
367+
.withLibraryName("aws-sql-jdbc-lettuce");
368+
369+
if (this.iamAuthEnabled) {
370+
// Create a credentials provider that Lettuce will call whenever authentication is needed
371+
RedisCredentialsProvider credentialsProvider = () -> {
372+
// Create a cached token supplier that automatically refreshes tokens every 14.5 minutes
373+
Supplier<String> tokenSupplier = CachedSupplier.memoizeWithExpiration(
374+
() -> {
375+
ElastiCacheIamTokenUtility tokenUtility = new ElastiCacheIamTokenUtility(this.cacheName);
376+
return tokenUtility.generateAuthenticationToken(
377+
this.credentialsProvider,
378+
Region.of(this.cacheIamRegion),
379+
this.defaultCacheServerHostAndPort[0],
380+
Integer.parseInt(this.defaultCacheServerHostAndPort[1]),
381+
this.cacheUsername
382+
);
383+
},
384+
TOKEN_CACHE_DURATION,
385+
TimeUnit.SECONDS
386+
);
387+
// Package the username and token (from cache or freshly generated) into Redis credentials
388+
return Mono.just(RedisCredentials.just(this.cacheUsername, tokenSupplier.get()));
389+
};
390+
uriBuilder.withAuthentication(credentialsProvider);
391+
} else if (!StringUtils.isNullOrEmpty(this.cachePassword)) {
392+
uriBuilder.withAuthentication(this.cacheUsername, this.cachePassword);
393+
}
394+
return uriBuilder.build();
395+
}
396+
397+
private String[] getHostnameAndPort(String serverAddr) {
398+
return serverAddr.split(":");
399+
}
266400
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package software.amazon.jdbc.plugin.cache;
18+
19+
import java.util.Objects;
20+
import java.util.concurrent.TimeUnit;
21+
import java.util.concurrent.locks.ReentrantLock;
22+
import java.util.function.Supplier;
23+
24+
public final class CachedSupplier {
25+
26+
private CachedSupplier() {
27+
throw new UnsupportedOperationException("Utility class should not be instantiated");
28+
}
29+
30+
public static <T> Supplier<T> memoizeWithExpiration(
31+
Supplier<T> delegate, long duration, TimeUnit unit) {
32+
33+
Objects.requireNonNull(delegate, "delegate Supplier must not be null");
34+
Objects.requireNonNull(unit, "TimeUnit must not be null");
35+
if (duration <= 0) {
36+
throw new IllegalArgumentException("duration must be > 0");
37+
}
38+
39+
return new ExpiringMemoizingSupplier<>(delegate, duration, unit);
40+
}
41+
42+
private static final class ExpiringMemoizingSupplier<T> implements Supplier<T> {
43+
44+
private final Supplier<T> delegate;
45+
private final long durationNanos;
46+
private final ReentrantLock lock = new ReentrantLock();
47+
48+
private volatile T value;
49+
private volatile long expirationNanos; // 0 means not yet initialized
50+
51+
ExpiringMemoizingSupplier(Supplier<T> delegate, long duration, TimeUnit unit) {
52+
this.delegate = delegate;
53+
this.durationNanos = unit.toNanos(duration);
54+
}
55+
56+
@Override
57+
public T get() {
58+
long now = System.nanoTime();
59+
60+
// Check if value is expired or uninitialized
61+
if (expirationNanos == 0 || now - expirationNanos >= 0) {
62+
lock.lock();
63+
try {
64+
if (expirationNanos == 0 || now - expirationNanos >= 0) {
65+
value = delegate.get();
66+
long next = now + durationNanos;
67+
expirationNanos = (next == 0) ? 1 : next; // avoid 0 sentinel
68+
}
69+
} finally {
70+
lock.unlock();
71+
}
72+
}
73+
return value;
74+
}
75+
}
76+
}

0 commit comments

Comments
 (0)