1+ /*
2+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+ *
4+ * Licensed under the Apache License, Version 2.0 (the "License").
5+ * You may not use this file except in compliance with the License.
6+ * You may obtain a copy of the License at
7+ *
8+ * http://www.apache.org/licenses/LICENSE-2.0
9+ *
10+ * Unless required by applicable law or agreed to in writing, software
11+ * distributed under the License is distributed on an "AS IS" BASIS,
12+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ * See the License for the specific language governing permissions and
14+ * limitations under the License.
15+ */
16+
117package software .amazon .jdbc .plugin .cache ;
218
319import io .lettuce .core .RedisClient ;
4- import io .lettuce .core .RedisCommandExecutionException ;
20+ import io .lettuce .core .RedisCredentials ;
21+ import io .lettuce .core .RedisCredentialsProvider ;
522import io .lettuce .core .RedisURI ;
23+ import io .lettuce .core .RedisCommandExecutionException ;
624import io .lettuce .core .SetArgs ;
725import io .lettuce .core .api .StatefulRedisConnection ;
826import io .lettuce .core .api .async .RedisAsyncCommands ;
927import io .lettuce .core .codec .ByteArrayCodec ;
1028import io .lettuce .core .resource .ClientResources ;
11- import software .amazon .jdbc .AwsWrapperProperty ;
29+ import reactor .core .publisher .Mono ;
30+ import software .amazon .awssdk .auth .credentials .AwsCredentialsProvider ;
31+ import software .amazon .awssdk .regions .Region ;
1232import java .nio .charset .StandardCharsets ;
1333import java .security .MessageDigest ;
1434import java .security .NoSuchAlgorithmException ;
1535import java .time .Duration ;
1636import java .util .Properties ;
37+ import java .util .concurrent .TimeUnit ;
1738import java .util .concurrent .locks .ReentrantLock ;
39+ import java .util .function .Supplier ;
1840import java .util .logging .Logger ;
1941import org .apache .commons .pool2 .BasePooledObjectFactory ;
2042import org .apache .commons .pool2 .impl .GenericObjectPool ;
2143import org .apache .commons .pool2 .impl .GenericObjectPoolConfig ;
2244import org .apache .commons .pool2 .impl .DefaultPooledObject ;
2345import org .apache .commons .pool2 .PooledObject ;
46+ import software .amazon .jdbc .AwsWrapperProperty ;
2447import software .amazon .jdbc .PropertyDefinition ;
48+ import software .amazon .jdbc .authentication .AwsCredentialsManager ;
49+ import software .amazon .jdbc .plugin .iam .ElastiCacheIamTokenUtility ;
2550import software .amazon .jdbc .util .StringUtils ;
2651
2752// Abstraction layer on top of a connection to a remote cache server
@@ -31,18 +56,21 @@ public class CacheConnection {
3156 private static volatile GenericObjectPool <StatefulRedisConnection <byte [], byte []>> readConnectionPool ;
3257 private static volatile GenericObjectPool <StatefulRedisConnection <byte [], byte []>> writeConnectionPool ;
3358 private static final GenericObjectPoolConfig <StatefulRedisConnection <byte [], byte []>> poolConfig = createPoolConfig ();
34- private final String cacheRwServerAddr ; // read-write cache server
35- private final String cacheRoServerAddr ; // read-only cache server
36- private MessageDigest msgHashDigest = null ;
3759
3860 private static final int DEFAULT_POOL_SIZE = 20 ;
3961 private static final int DEFAULT_POOL_MAX_IDLE = 20 ;
4062 private static final int DEFAULT_POOL_MIN_IDLE = 0 ;
4163 private static final long DEFAULT_MAX_BORROW_WAIT_MS = 100 ;
64+ private static final long TOKEN_CACHE_DURATION = 15 * 60 - 30 ;
4265
4366 private static final ReentrantLock READ_LOCK = new ReentrantLock ();
4467 private static final ReentrantLock WRITE_LOCK = new ReentrantLock ();
4568
69+ private final String cacheRwServerAddr ; // read-write cache server
70+ private final String cacheRoServerAddr ; // read-only cache server
71+ private final String [] defaultCacheServerHostAndPort ;
72+ private MessageDigest msgHashDigest = null ;
73+
4674 protected static final AwsWrapperProperty CACHE_RW_ENDPOINT_ADDR =
4775 new AwsWrapperProperty (
4876 "cacheEndpointAddrRw" ,
@@ -61,7 +89,38 @@ public class CacheConnection {
6189 "true" ,
6290 "Whether to use SSL for cache connections." );
6391
92+ protected static final AwsWrapperProperty CACHE_IAM_REGION =
93+ new AwsWrapperProperty (
94+ "cacheIamRegion" ,
95+ null ,
96+ "AWS region for ElastiCache IAM authentication." );
97+
98+ protected static final AwsWrapperProperty CACHE_USERNAME =
99+ new AwsWrapperProperty (
100+ "cacheUsername" ,
101+ null ,
102+ "Username for ElastiCache regular authentication." );
103+
104+ protected static final AwsWrapperProperty CACHE_PASSWORD =
105+ new AwsWrapperProperty (
106+ "cachePassword" ,
107+ null ,
108+ "Password for ElastiCache regular authentication." );
109+
110+ protected static final AwsWrapperProperty CACHE_NAME =
111+ new AwsWrapperProperty (
112+ "cacheName" ,
113+ null ,
114+ "Explicit cache name for ElastiCache IAM authentication. " );
115+
64116 private final boolean useSSL ;
117+ private final boolean iamAuthEnabled ;
118+ private final String cacheIamRegion ;
119+ private final String cacheUsername ;
120+ private final String cacheName ;
121+ private final String cachePassword ;
122+ private final Properties awsProfileProperties ;
123+ private final AwsCredentialsProvider credentialsProvider ;
65124
66125 static {
67126 PropertyDefinition .registerPluginProperties (CacheConnection .class );
@@ -71,6 +130,44 @@ public CacheConnection(final Properties properties) {
71130 this .cacheRwServerAddr = CACHE_RW_ENDPOINT_ADDR .getString (properties );
72131 this .cacheRoServerAddr = CACHE_RO_ENDPOINT_ADDR .getString (properties );
73132 this .useSSL = Boolean .parseBoolean (CACHE_USE_SSL .getString (properties ));
133+ this .cacheName = CACHE_NAME .getString (properties );
134+ this .cacheIamRegion = CACHE_IAM_REGION .getString (properties );
135+ this .cacheUsername = CACHE_USERNAME .getString (properties );
136+ this .cachePassword = CACHE_PASSWORD .getString (properties );
137+ this .iamAuthEnabled = !StringUtils .isNullOrEmpty (this .cacheIamRegion );
138+ boolean hasTraditionalAuth = !StringUtils .isNullOrEmpty (this .cachePassword );
139+ // Validate authentication configuration
140+ if (this .iamAuthEnabled && hasTraditionalAuth ) {
141+ throw new IllegalArgumentException (
142+ "Cannot specify both IAM authentication (cacheIamRegion) and traditional authentication (cachePassword). Choose one authentication method." );
143+ }
144+ if (this .cacheRwServerAddr == null ) {
145+ throw new IllegalArgumentException ("Cache endpoint address is required" );
146+ }
147+ this .defaultCacheServerHostAndPort = getHostnameAndPort (this .cacheRwServerAddr );
148+ if (this .iamAuthEnabled ) {
149+ if (this .cacheUsername == null || this .defaultCacheServerHostAndPort [0 ] == null || this .cacheName == null ) {
150+ throw new IllegalArgumentException ("IAM authentication requires cache name, username, region, and hostname" );
151+ }
152+ }
153+ if (PropertyDefinition .AWS_PROFILE .getString (properties ) != null ) {
154+ this .awsProfileProperties = new Properties ();
155+ this .awsProfileProperties .setProperty (
156+ PropertyDefinition .AWS_PROFILE .name ,
157+ PropertyDefinition .AWS_PROFILE .getString (properties )
158+ );
159+ } else {
160+ this .awsProfileProperties = null ;
161+ }
162+ if (this .iamAuthEnabled ) {
163+ // Handle null case
164+ Properties propsToPass = (this .awsProfileProperties != null )
165+ ? this .awsProfileProperties
166+ : new Properties ();
167+ this .credentialsProvider = AwsCredentialsManager .getProvider (null , propsToPass );
168+ } else {
169+ this .credentialsProvider = null ;
170+ }
74171 }
75172
76173 /* Here we check if we need to initialise connection pool for read or write to cache.
@@ -113,15 +210,14 @@ private void createConnectionPool(boolean isRead) {
113210 if (isRead && !StringUtils .isNullOrEmpty (cacheRoServerAddr )) {
114211 serverAddr = cacheRoServerAddr ;
115212 }
116- String [] hostnameAndPort = serverAddr .split (":" );
117- RedisURI redisUriCluster = RedisURI .Builder .redis (hostnameAndPort [0 ])
118- .withPort (Integer .parseInt (hostnameAndPort [1 ]))
119- .withSsl (useSSL ).withVerifyPeer (false ).withLibraryName ("aws-sql-jdbc-lettuce" ).build ();
213+ String [] hostnameAndPort = getHostnameAndPort (serverAddr );
214+ RedisURI redisUriCluster = buildRedisURI (hostnameAndPort [0 ], Integer .parseInt (hostnameAndPort [1 ]));
120215
121216 RedisClient client = RedisClient .create (resources , redisUriCluster );
122217 GenericObjectPool <StatefulRedisConnection <byte [], byte []>> pool = new GenericObjectPool <>(
123218 new BasePooledObjectFactory <StatefulRedisConnection <byte [], byte []>>() {
124219 public StatefulRedisConnection <byte [], byte []> create () {
220+
125221 StatefulRedisConnection <byte [], byte []> connection = client .connect (new ByteArrayCodec ());
126222 // In cluster mode, we need to send READONLY command to the server for reading from replica.
127223 // Note: we gracefully ignore ERR reply to support non cluster mode.
@@ -148,7 +244,6 @@ public PooledObject<StatefulRedisConnection<byte[], byte[]>> wrap(StatefulRedisC
148244 } else {
149245 writeConnectionPool = pool ;
150246 }
151-
152247 } catch (Exception e ) {
153248 String poolType = isRead ? "read" : "write" ;
154249 String errorMsg = String .format ("Failed to create Cache %s connection pool" , poolType );
@@ -247,13 +342,13 @@ public void writeToCache(String key, byte[] value, int expiry) {
247342 private void returnConnectionBackToPool (StatefulRedisConnection <byte [], byte []> connection , boolean isConnectionBroken , boolean isRead ) {
248343 GenericObjectPool <StatefulRedisConnection <byte [], byte []>> pool = isRead ? readConnectionPool : writeConnectionPool ;
249344 if (isConnectionBroken ) {
250- try {
251- pool .invalidateObject (connection );
252- } catch (Exception e ) {
253- throw new RuntimeException ("Could not invalidate connection for the pool" , e );
254- }
345+ try {
346+ pool .invalidateObject (connection );
347+ } catch (Exception e ) {
348+ throw new RuntimeException ("Could not invalidate connection for the pool" , e );
349+ }
255350 } else {
256- pool .returnObject (connection );
351+ pool .returnObject (connection );
257352 }
258353 }
259354
@@ -263,4 +358,43 @@ protected void setConnectionPools(GenericObjectPool<StatefulRedisConnection<byte
263358 readConnectionPool = readPool ;
264359 writeConnectionPool = writePool ;
265360 }
361+
362+ protected RedisURI buildRedisURI (String hostname , int port ) {
363+ RedisURI .Builder uriBuilder = RedisURI .Builder .redis (hostname )
364+ .withPort (port )
365+ .withSsl (useSSL )
366+ .withVerifyPeer (false )
367+ .withLibraryName ("aws-sql-jdbc-lettuce" );
368+
369+ if (this .iamAuthEnabled ) {
370+ // Create a credentials provider that Lettuce will call whenever authentication is needed
371+ RedisCredentialsProvider credentialsProvider = () -> {
372+ // Create a cached token supplier that automatically refreshes tokens every 14.5 minutes
373+ Supplier <String > tokenSupplier = CachedSupplier .memoizeWithExpiration (
374+ () -> {
375+ ElastiCacheIamTokenUtility tokenUtility = new ElastiCacheIamTokenUtility (this .cacheName );
376+ return tokenUtility .generateAuthenticationToken (
377+ this .credentialsProvider ,
378+ Region .of (this .cacheIamRegion ),
379+ this .defaultCacheServerHostAndPort [0 ],
380+ Integer .parseInt (this .defaultCacheServerHostAndPort [1 ]),
381+ this .cacheUsername
382+ );
383+ },
384+ TOKEN_CACHE_DURATION ,
385+ TimeUnit .SECONDS
386+ );
387+ // Package the username and token (from cache or freshly generated) into Redis credentials
388+ return Mono .just (RedisCredentials .just (this .cacheUsername , tokenSupplier .get ()));
389+ };
390+ uriBuilder .withAuthentication (credentialsProvider );
391+ } else if (!StringUtils .isNullOrEmpty (this .cachePassword )) {
392+ uriBuilder .withAuthentication (this .cacheUsername , this .cachePassword );
393+ }
394+ return uriBuilder .build ();
395+ }
396+
397+ private String [] getHostnameAndPort (String serverAddr ) {
398+ return serverAddr .split (":" );
399+ }
266400}
0 commit comments