From 8927850f41527cd9f13598439225c1115872c596 Mon Sep 17 00:00:00 2001 From: Aaron Chung Date: Wed, 23 Jul 2025 17:21:11 -0700 Subject: [PATCH 1/3] feat - custom endpoint plugin --- .../driver/connection_plugin_chain_builder.go | 1 + .../allowed_and_blocked_hosts.go | 48 ++++ .../fixed_value_types.go | 1 + .../driver_infrastructure/plugin_helpers.go | 2 + awssql/host_info_util/host_info.go | 4 +- awssql/host_info_util/host_info_util.go | 4 + awssql/plugin_helpers/plugin_service.go | 47 +++- awssql/property_util/aws_wrapper_property.go | 186 ++++++++------ awssql/resources/en.json | 10 + awssql/utils/sliding_expiration_cache.go | 21 ++ custom-endpoint/custom_endpoint_info.go | 103 ++++++++ custom-endpoint/custom_endpoint_monitor.go | 182 ++++++++++++++ custom-endpoint/custom_endpoint_plugin.go | 228 ++++++++++++++++++ custom-endpoint/go.mod | 33 +++ custom-endpoint/go.sum | 40 +++ 15 files changed, 836 insertions(+), 74 deletions(-) create mode 100644 awssql/driver_infrastructure/allowed_and_blocked_hosts.go create mode 100644 custom-endpoint/custom_endpoint_info.go create mode 100644 custom-endpoint/custom_endpoint_monitor.go create mode 100644 custom-endpoint/custom_endpoint_plugin.go create mode 100644 custom-endpoint/go.mod create mode 100644 custom-endpoint/go.sum diff --git a/awssql/driver/connection_plugin_chain_builder.go b/awssql/driver/connection_plugin_chain_builder.go index 0546f8e8..0deac317 100644 --- a/awssql/driver/connection_plugin_chain_builder.go +++ b/awssql/driver/connection_plugin_chain_builder.go @@ -37,6 +37,7 @@ type PluginFactoryWeight struct { } var pluginWeightByCode = map[string]int{ + driver_infrastructure.CUSTOM_ENDPOINT_PLUGIN_CODE: 380, driver_infrastructure.AURORA_CONNECTION_TRACKER_PLUGIN_CODE: 400, driver_infrastructure.BLUE_GREEN_PLUGIN_CODE: 550, driver_infrastructure.READ_WRITE_SPLITTING_PLUGIN_CODE: 600, diff --git a/awssql/driver_infrastructure/allowed_and_blocked_hosts.go b/awssql/driver_infrastructure/allowed_and_blocked_hosts.go new file mode 100644 index 00000000..81c80536 --- /dev/null +++ b/awssql/driver_infrastructure/allowed_and_blocked_hosts.go @@ -0,0 +1,48 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package driver_infrastructure + +type AllowedAndBlockedHosts struct { + allowedHostIds map[string]bool + blockedHostIds map[string]bool +} + +func NewAllowedAndBlockedHosts( + allowedHostIds map[string]bool, + blockedHostIds map[string]bool) *AllowedAndBlockedHosts { + var allowedHostIdsToSet map[string]bool + var blockedHostIdsToSet map[string]bool + + if allowedHostIds != nil && len(allowedHostIds) > 0 { + allowedHostIdsToSet = allowedHostIds + } + if blockedHostIds != nil && len(blockedHostIds) > 0 { + blockedHostIdsToSet = blockedHostIds + } + return &AllowedAndBlockedHosts{ + allowedHostIds: allowedHostIdsToSet, + blockedHostIds: blockedHostIdsToSet, + } +} + +func (a *AllowedAndBlockedHosts) GetAllowedHostIds() map[string]bool { + return a.allowedHostIds +} + +func (a *AllowedAndBlockedHosts) GetBlockedHostIds() map[string]bool { + return a.blockedHostIds +} diff --git a/awssql/driver_infrastructure/fixed_value_types.go b/awssql/driver_infrastructure/fixed_value_types.go index 5c1b83fb..e4f6445d 100644 --- a/awssql/driver_infrastructure/fixed_value_types.go +++ b/awssql/driver_infrastructure/fixed_value_types.go @@ -17,6 +17,7 @@ package driver_infrastructure const ( + CUSTOM_ENDPOINT_PLUGIN_CODE string = "customEndpoint" BLUE_GREEN_PLUGIN_CODE string = "bg" READ_WRITE_SPLITTING_PLUGIN_CODE string = "readWriteSplitting" FAILOVER_PLUGIN_CODE string = "failover" diff --git a/awssql/driver_infrastructure/plugin_helpers.go b/awssql/driver_infrastructure/plugin_helpers.go index bccec98b..67431517 100644 --- a/awssql/driver_infrastructure/plugin_helpers.go +++ b/awssql/driver_infrastructure/plugin_helpers.go @@ -49,7 +49,9 @@ type PluginService interface { SetCurrentConnection(conn driver.Conn, hostInfo *host_info_util.HostInfo, skipNotificationForThisPlugin ConnectionPlugin) error GetInitialConnectionHostInfo() *host_info_util.HostInfo GetCurrentHostInfo() (*host_info_util.HostInfo, error) + GetAllHosts() []*host_info_util.HostInfo GetHosts() []*host_info_util.HostInfo + SetAllowedAndBlockedHosts(allowedAndBlockedHosts *AllowedAndBlockedHosts) AcceptsStrategy(strategy string) bool GetHostInfoByStrategy(role host_info_util.HostRole, strategy string, hosts []*host_info_util.HostInfo) (*host_info_util.HostInfo, error) GetHostSelectorStrategy(strategy string) (hostSelector HostSelector, err error) diff --git a/awssql/host_info_util/host_info.go b/awssql/host_info_util/host_info.go index 7f18e96b..0168745b 100644 --- a/awssql/host_info_util/host_info.go +++ b/awssql/host_info_util/host_info.go @@ -131,8 +131,8 @@ func (hostInfo *HostInfo) IsNil() bool { } func (hostInfo *HostInfo) String() string { - return fmt.Sprintf("HostInfo[host=%s, port=%d, %s, %s, weight=%d, %s]", - hostInfo.Host, hostInfo.Port, hostInfo.Role, hostInfo.Availability, hostInfo.Weight, hostInfo.LastUpdateTime) + return fmt.Sprintf("HostInfo[hostId=%s,host=%s, port=%d, %s, %s, weight=%d, %s]", + hostInfo.HostId, hostInfo.Host, hostInfo.Port, hostInfo.Role, hostInfo.Availability, hostInfo.Weight, hostInfo.LastUpdateTime) } func (hostInfo *HostInfo) MakeCopyWithRole(role HostRole) *HostInfo { diff --git a/awssql/host_info_util/host_info_util.go b/awssql/host_info_util/host_info_util.go index 8d54d1e3..f5a45003 100644 --- a/awssql/host_info_util/host_info_util.go +++ b/awssql/host_info_util/host_info_util.go @@ -87,6 +87,10 @@ func HaveNoHostsInCommon(hosts1 []*HostInfo, hosts2 []*HostInfo) bool { } func IsHostInList(host *HostInfo, hosts []*HostInfo) bool { + if len(hosts) < 1 { + return false + } + for _, h := range hosts { if h.Equals(host) { return true diff --git a/awssql/plugin_helpers/plugin_service.go b/awssql/plugin_helpers/plugin_service.go index 69479977..a4299304 100644 --- a/awssql/plugin_helpers/plugin_service.go +++ b/awssql/plugin_helpers/plugin_service.go @@ -23,6 +23,7 @@ import ( "log/slog" "slices" "strings" + "sync/atomic" "time" "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" @@ -54,6 +55,7 @@ type PluginServiceImpl struct { isInTransaction bool currentTx driver.Tx sessionStateService driver_infrastructure.SessionStateService + allowedAndBlockedHosts atomic.Pointer[driver_infrastructure.AllowedAndBlockedHosts] } func NewPluginServiceImpl( @@ -247,6 +249,17 @@ func (p *PluginServiceImpl) GetCurrentHostInfo() (*host_info_util.HostInfo, erro } p.currentHostInfo = host_info_util.GetWriter(p.AllHosts) + allowedHosts := p.GetHosts() + if !host_info_util.IsHostInList(p.currentHostInfo, allowedHosts) { + if p.currentHostInfo == nil { + return nil, error_util.NewGenericAwsWrapperError( + error_util.GetMessage("PluginServiceImpl.currentHostNotAllowed", p.currentHostInfo.GetHostAndPort(), utils.LogTopology(allowedHosts, ""))) + } else { + return nil, error_util.NewGenericAwsWrapperError( + error_util.GetMessage("PluginServiceImpl.currentHostNotAllowed", "", utils.LogTopology(allowedHosts, ""))) + } + } + if p.currentHostInfo.IsNil() { p.currentHostInfo = p.AllHosts[0] } @@ -259,14 +272,46 @@ func (p *PluginServiceImpl) GetCurrentHostInfo() (*host_info_util.HostInfo, erro return p.currentHostInfo, nil } -func (p *PluginServiceImpl) GetHosts() []*host_info_util.HostInfo { +// TODO: transfer some uses of #GetHost to #GetAllHosts +func (p *PluginServiceImpl) GetAllHosts() []*host_info_util.HostInfo { return p.AllHosts } +func (p *PluginServiceImpl) GetHosts() []*host_info_util.HostInfo { + hostPermissions := p.allowedAndBlockedHosts.Load() + if hostPermissions == nil { + return p.AllHosts + } + + hosts := p.AllHosts + allowedHosts := p.allowedAndBlockedHosts.Load().GetAllowedHostIds() + blockedHosts := p.allowedAndBlockedHosts.Load().GetBlockedHostIds() + + if allowedHosts != nil && len(allowedHosts) > 0 { + hosts = utils.FilterSlice(hosts, func(item *host_info_util.HostInfo) bool { + value, ok := allowedHosts[item.HostId] + return ok && value + }) + } + + if blockedHosts != nil && len(blockedHosts) > 0 { + hosts = utils.FilterSlice(hosts, func(item *host_info_util.HostInfo) bool { + value, ok := blockedHosts[item.HostId] + return !ok || !value + }) + } + + return hosts +} + func (p *PluginServiceImpl) GetInitialConnectionHostInfo() *host_info_util.HostInfo { return p.initialHostInfo } +func (p *PluginServiceImpl) SetAllowedAndBlockedHosts(allowedAndBlockedHosts *driver_infrastructure.AllowedAndBlockedHosts) { + p.allowedAndBlockedHosts.Store(allowedAndBlockedHosts) +} + func (p *PluginServiceImpl) AcceptsStrategy(strategy string) bool { return p.pluginManager.AcceptsStrategy(strategy) } diff --git a/awssql/property_util/aws_wrapper_property.go b/awssql/property_util/aws_wrapper_property.go index cb33178c..2c559bef 100644 --- a/awssql/property_util/aws_wrapper_property.go +++ b/awssql/property_util/aws_wrapper_property.go @@ -150,77 +150,82 @@ func GetRefreshRateValue(props *utils.RWMap[string, string], property AwsWrapper } var ALL_WRAPPER_PROPERTIES = map[string]bool{ - USER.Name: true, - PASSWORD.Name: true, - HOST.Name: true, - PORT.Name: true, - DATABASE.Name: true, - DRIVER_PROTOCOL.Name: true, - NET.Name: true, - SINGLE_WRITER_DSN.Name: true, - PLUGINS.Name: true, - AUTO_SORT_PLUGIN_ORDER.Name: true, - DIALECT.Name: true, - TARGET_DRIVER_DIALECT.Name: true, - TARGET_DRIVER_AUTO_REGISTER.Name: true, - CLUSTER_TOPOLOGY_REFRESH_RATE_MS.Name: true, - CLUSTER_ID.Name: true, - CLUSTER_INSTANCE_HOST_PATTERN.Name: true, - AWS_PROFILE.Name: true, - IAM_HOST.Name: true, - IAM_EXPIRATION_SEC.Name: true, - IAM_REGION.Name: true, - IAM_DEFAULT_PORT.Name: true, - SECRETS_MANAGER_SECRET_ID.Name: true, - SECRETS_MANAGER_REGION.Name: true, - SECRETS_MANAGER_ENDPOINT.Name: true, - SECRETS_MANAGER_EXPIRATION_SEC.Name: true, - FAILURE_DETECTION_TIME_MS.Name: true, - FAILURE_DETECTION_INTERVAL_MS.Name: true, - FAILURE_DETECTION_COUNT.Name: true, - MONITOR_DISPOSAL_TIME_MS.Name: true, - FAILOVER_TIMEOUT_MS.Name: true, - FAILOVER_MODE.Name: true, - FAILOVER_READER_HOST_SELECTOR_STRATEGY.Name: true, - ENABLE_CONNECT_FAILOVER.Name: true, - CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.Name: true, - WEIGHTED_RANDOM_HOST_WEIGHT_PAIRS.Name: true, - IAM_TOKEN_EXPIRATION_SEC.Name: true, - IDP_USERNAME.Name: true, - IDP_PASSWORD.Name: true, - IDP_PORT.Name: true, - IAM_ROLE_ARN.Name: true, - IAM_IDP_ARN.Name: true, - IDP_ENDPOINT.Name: true, - RELAYING_PARTY_ID.Name: true, - DB_USER.Name: true, - APP_ID.Name: true, - HTTP_TIMEOUT_MS.Name: true, - SSL_INSECURE.Name: true, - ENABLE_TELEMETRY.Name: true, - TELEMETRY_SUBMIT_TOP_LEVEL.Name: true, - TELEMETRY_TRACES_BACKEND.Name: true, - TELEMETRY_METRICS_BACKEND.Name: true, - TELEMETRY_FAILOVER_ADDITIONAL_TOP_TRACE.Name: true, - LIMITLESS_MONITORING_INTERVAL_MS.Name: true, - LIMITLESS_MONITORING_DISPOSAL_TIME_MS.Name: true, - LIMITLESS_ROUTER_CACHE_EXPIRATION_TIME_MS.Name: true, - LIMITLESS_WAIT_FOR_ROUTER_INFO.Name: true, - LIMITLESS_GET_ROUTER_MAX_RETRIES.Name: true, - LIMITLESS_GET_ROUTER_RETRY_INTERVAL_MS.Name: true, - LIMITLESS_MAX_CONN_RETRIES.Name: true, - LIMITLESS_ROUTER_QUERY_TIMEOUT_MS.Name: true, - TRANSFER_SESSION_STATE_ON_SWITCH.Name: true, - RESET_SESSION_STATE_ON_CLOSE.Name: true, - ROLLBACK_ON_SWITCH.Name: true, - READER_HOST_SELECTOR_STRATEGY.Name: true, - BG_CONNECT_TIMEOUT_MS.Name: true, - BGD_ID.Name: true, - BG_INTERVAL_BASELINE_MS.Name: true, - BG_INTERVAL_INCREASED_MS.Name: true, - BG_INTERVAL_HIGH_MS.Name: true, - BG_SWITCHOVER_TIMEOUT_MS.Name: true, - BG_SUSPEND_NEW_BLUE_CONNECTIONS.Name: true, + USER.Name: true, + PASSWORD.Name: true, + HOST.Name: true, + PORT.Name: true, + DATABASE.Name: true, + DRIVER_PROTOCOL.Name: true, + NET.Name: true, + SINGLE_WRITER_DSN.Name: true, + PLUGINS.Name: true, + AUTO_SORT_PLUGIN_ORDER.Name: true, + DIALECT.Name: true, + TARGET_DRIVER_DIALECT.Name: true, + TARGET_DRIVER_AUTO_REGISTER.Name: true, + CLUSTER_TOPOLOGY_REFRESH_RATE_MS.Name: true, + CLUSTER_ID.Name: true, + CLUSTER_INSTANCE_HOST_PATTERN.Name: true, + AWS_PROFILE.Name: true, + IAM_HOST.Name: true, + IAM_EXPIRATION_SEC.Name: true, + IAM_REGION.Name: true, + IAM_DEFAULT_PORT.Name: true, + SECRETS_MANAGER_SECRET_ID.Name: true, + SECRETS_MANAGER_REGION.Name: true, + SECRETS_MANAGER_ENDPOINT.Name: true, + SECRETS_MANAGER_EXPIRATION_SEC.Name: true, + FAILURE_DETECTION_TIME_MS.Name: true, + FAILURE_DETECTION_INTERVAL_MS.Name: true, + FAILURE_DETECTION_COUNT.Name: true, + MONITOR_DISPOSAL_TIME_MS.Name: true, + FAILOVER_TIMEOUT_MS.Name: true, + FAILOVER_MODE.Name: true, + FAILOVER_READER_HOST_SELECTOR_STRATEGY.Name: true, + ENABLE_CONNECT_FAILOVER.Name: true, + CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.Name: true, + WEIGHTED_RANDOM_HOST_WEIGHT_PAIRS.Name: true, + IAM_TOKEN_EXPIRATION_SEC.Name: true, + IDP_USERNAME.Name: true, + IDP_PASSWORD.Name: true, + IDP_PORT.Name: true, + IAM_ROLE_ARN.Name: true, + IAM_IDP_ARN.Name: true, + IDP_ENDPOINT.Name: true, + RELAYING_PARTY_ID.Name: true, + DB_USER.Name: true, + APP_ID.Name: true, + HTTP_TIMEOUT_MS.Name: true, + SSL_INSECURE.Name: true, + ENABLE_TELEMETRY.Name: true, + TELEMETRY_SUBMIT_TOP_LEVEL.Name: true, + TELEMETRY_TRACES_BACKEND.Name: true, + TELEMETRY_METRICS_BACKEND.Name: true, + TELEMETRY_FAILOVER_ADDITIONAL_TOP_TRACE.Name: true, + LIMITLESS_MONITORING_INTERVAL_MS.Name: true, + LIMITLESS_MONITORING_DISPOSAL_TIME_MS.Name: true, + LIMITLESS_ROUTER_CACHE_EXPIRATION_TIME_MS.Name: true, + LIMITLESS_WAIT_FOR_ROUTER_INFO.Name: true, + LIMITLESS_GET_ROUTER_MAX_RETRIES.Name: true, + LIMITLESS_GET_ROUTER_RETRY_INTERVAL_MS.Name: true, + LIMITLESS_MAX_CONN_RETRIES.Name: true, + LIMITLESS_ROUTER_QUERY_TIMEOUT_MS.Name: true, + TRANSFER_SESSION_STATE_ON_SWITCH.Name: true, + RESET_SESSION_STATE_ON_CLOSE.Name: true, + ROLLBACK_ON_SWITCH.Name: true, + READER_HOST_SELECTOR_STRATEGY.Name: true, + BG_CONNECT_TIMEOUT_MS.Name: true, + BGD_ID.Name: true, + BG_INTERVAL_BASELINE_MS.Name: true, + BG_INTERVAL_INCREASED_MS.Name: true, + BG_INTERVAL_HIGH_MS.Name: true, + BG_SWITCHOVER_TIMEOUT_MS.Name: true, + BG_SUSPEND_NEW_BLUE_CONNECTIONS.Name: true, + CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS.Name: true, + WAIT_FOR_CUSTOM_ENDPOINT_INFO.Name: true, + WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS.Name: true, + CUSTOM_ENDPOINT_MONITOR_IDLE_EXPIRATION_MS.Name: true, + CUSTOM_ENDPOINT_REGION_PROPERTY.Name: true, } var USER = AwsWrapperProperty{ @@ -742,6 +747,45 @@ var ROUND_ROBIN_DEFAULT_WEIGHT = AwsWrapperProperty{ wrapperPropertyType: WRAPPER_TYPE_INT, } +var CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS = AwsWrapperProperty{ + Name: "customEndpointInfoRefreshRateMs", + description: "Controls how frequently custom endpoint monitors fetch custom endpoint info, in milliseconds.", + defaultValue: "30000", + wrapperPropertyType: WRAPPER_TYPE_INT, +} + +var WAIT_FOR_CUSTOM_ENDPOINT_INFO = AwsWrapperProperty{ + Name: "waitForCustomEndpointInfo", + description: "Controls whether to wait for custom endpoint info to become available before connecting or executing a " + + "method. Waiting is only necessary if a connection to a given custom endpoint has not been opened or used " + + "recently. Note that disabling this may result in occasional connections to instances outside of the " + + "custom endpoint.", + defaultValue: "true", + wrapperPropertyType: WRAPPER_TYPE_BOOL, +} + +var WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS = AwsWrapperProperty{ + Name: "waitForCustomEndpointInfoTimeoutMs", + description: "Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made " + + "available by the custom endpoint monitor, in milliseconds.", + defaultValue: "5000", + wrapperPropertyType: WRAPPER_TYPE_INT, +} + +var CUSTOM_ENDPOINT_MONITOR_IDLE_EXPIRATION_MS = AwsWrapperProperty{ + Name: "customEndpointMonitorExpirationMs", + description: "Controls how long a monitor should run without use before expiring and being removed, in milliseconds.", + defaultValue: "900000", + wrapperPropertyType: WRAPPER_TYPE_INT, +} + +var CUSTOM_ENDPOINT_REGION_PROPERTY = AwsWrapperProperty{ + Name: "customEndpointRegion", + description: "The region of the cluster's custom endpoints. If not specified, the region will be parsed from the URL.", + defaultValue: "", + wrapperPropertyType: WRAPPER_TYPE_STRING, +} + func RemoveInternalAwsWrapperProperties(props map[string]string) map[string]string { copyProps := map[string]string{} diff --git a/awssql/resources/en.json b/awssql/resources/en.json index 5615c9ad..ffcb45f6 100644 --- a/awssql/resources/en.json +++ b/awssql/resources/en.json @@ -85,6 +85,15 @@ "Conn.invalidTransactionIsolationLevel": "An invalid transaction isolation level was provided: '%v'.", "ConnectionPluginManager.unknownPluginCode": "Unknown plugin code: '%s'. Please ensure all plugin codes are valid and any required plugin modules have been imported.", "ConnectionProvider.unsupportedHostSelectorStrategy": "Unsupported host selection strategy '%v' specified for this connection provider '%T'. Please visit the documentation for all supported strategies.", + "CustomEndpointMonitorImpl.clearCache": "Clearing info in the custom endpoint monitor info cache.", + "CustomEndpointMonitorImpl.detectedChangeInCustomEndpointInfo": "Detected change in custom endpoint info for %s:\n %s", + "CustomEndpointMonitorImpl.error": "Encountered an exception while monitoring custom endpoint %s.", + "CustomEndpointMonitorImpl.nilResponse": "Unexpected nil response received from AWS SDK call to DescribeDBClusterEndpoints. Please check that AWS credentials have been properly set and the Custom Endpoint URL is correct.", + "CustomEndpointMonitorImpl.stoppedMonitor": "Stopped custom endpoint monitor for %s.", + "CustomEndpointMonitorImpl.unexpectedNumberOfEndpoints": "Unexpected number of custom endpoints with endpoint identifier %s in region %s. Expected 1, but found %v. Endpoints:\n%s", + "CustomEndpointPlugin.errorParsingEndpointIdentifier": "Unable to parse custom endpoint identifier from URL: %s", + "CustomEndpointPlugin.timedOutWaitingForCustomEndpointInfo": "The custom endpoint plugin timed out after %v ms while waiting for custom endpoint info for host %s", + "CustomEndpointPlugin.unableToDetermineRegion": "Unable to determine connection region. If you are using a non-standard RDS URL, please set the %s property.", "DatabaseDialect.invalidTransactionIsolationLevel": "An invalid transaction isolation level was provided: '%s'.", "DatabaseDialect.usingMonitoringHostListProvider": "Failover is enabled. Using MonitoringRdsHostListProvider.", "DatabaseDialect.usingRdsHostListProvider": "Failover is not enabled. Using RdsHostListProvider.", @@ -207,6 +216,7 @@ "OpenedConnectionTracker.unableToPopulateOpenedConnectionQueue": "The driver is unable to track this opened connection because the instance endpoint is unknown: '%s'", "PluginManager.pipelineNone": "A pipeline was requested but the created pipeline evaluated to nil.", "PluginManager.unknownPluginCode": "Unknown plugin code: '%v'.", + "PluginServiceImpl.currentHostNotAllowed": "The current host is not in the list of allowed hosts. Current host: %v. Allowed hosts: %v.", "PluginManagerImpl.invokedAgainstOldConnection": "The internal connection has changed since %v was created, skip executing method %v. This is likely due to failover. To ensure you are using the updated connection, please re-create Statement, Tx, Result and Row objects after failover.", "PluginManagerImpl.releaseResources": "Releasing resources from PluginManagerImpl.", "PluginManagerImpl.unsupportedHostSelectionStrategy": "The wrapper does not support the requested host selection strategy: %v.", diff --git a/awssql/utils/sliding_expiration_cache.go b/awssql/utils/sliding_expiration_cache.go index 1cca7fe1..fe591197 100644 --- a/awssql/utils/sliding_expiration_cache.go +++ b/awssql/utils/sliding_expiration_cache.go @@ -105,6 +105,27 @@ func (c *SlidingExpirationCache[T]) ComputeIfAbsent(key string, computeFunc func return c.cache[key].item } +func (c *SlidingExpirationCache[T]) ComputeIfAbsentWithError(key string, computeFunc func() (T, error), itemExpiration time.Duration) (T, error) { + item, ok := c.Get(key, itemExpiration) + + if ok { + return item, nil + } + + c.lock.Lock() + defer c.lock.Unlock() + item, err := computeFunc() + if err != nil { + var zeroValue T + return zeroValue, err + } + c.cache[key] = &cacheItem[T]{cacheValue[T]{ + item: item, + expirationTime: time.Now().Add(itemExpiration), + }} + return c.cache[key].item, nil +} + func (c *SlidingExpirationCache[T]) PutIfAbsent(key string, value T, expiration time.Duration) { c.cleanupIfExpired(key) c.lock.Lock() diff --git a/custom-endpoint/custom_endpoint_info.go b/custom-endpoint/custom_endpoint_info.go new file mode 100644 index 00000000..414bd11d --- /dev/null +++ b/custom-endpoint/custom_endpoint_info.go @@ -0,0 +1,103 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package custom_endpoint + +import ( + "reflect" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/rds/types" +) + +type CustomEndpointInfo struct { + endpointIdentifier string + clusterIdentifier string + url string + roleType RoleType + memberListType MemberListType + members map[string]bool +} + +type RoleType string + +const ( + ANY RoleType = "ANY" + WRITER RoleType = "WRITER" + READER RoleType = "READER" +) + +type MemberListType string + +const ( + STATIC_LIST MemberListType = "STATIC_LIST" + EXCLUSION_LIST MemberListType = "EXCLUSION_LIST" +) + +func NewCustomEndpointInfo(endpoint types.DBClusterEndpoint) *CustomEndpointInfo { + var members []string + var memberListType MemberListType + + if len(endpoint.StaticMembers) > 1 { + members = endpoint.StaticMembers + memberListType = STATIC_LIST + } else { + members = endpoint.ExcludedMembers + memberListType = EXCLUSION_LIST + } + + return &CustomEndpointInfo{ + endpointIdentifier: *endpoint.DBClusterEndpointIdentifier, + clusterIdentifier: *endpoint.DBClusterIdentifier, + url: *endpoint.Endpoint, + roleType: RoleType(strings.ToUpper(*endpoint.CustomEndpointType)), + memberListType: memberListType, + members: stringSliceToSetMap(members), + } +} + +func (a *CustomEndpointInfo) Equals(b *CustomEndpointInfo) bool { + return a.endpointIdentifier == b.endpointIdentifier && + a.clusterIdentifier == b.clusterIdentifier && + a.url == b.url && + a.roleType == b.roleType && + a.memberListType == b.memberListType && + reflect.DeepEqual(a.members, b.members) +} + +func (a *CustomEndpointInfo) GetStaticMembers() map[string]bool { + if STATIC_LIST == a.memberListType { + return a.members + } else { + return nil + } +} + +func (a *CustomEndpointInfo) GetExcludedMembers() map[string]bool { + if EXCLUSION_LIST == a.memberListType { + return a.members + } else { + return nil + } +} + +func stringSliceToSetMap(stringSlice []string) map[string]bool { + setMapToReturn := make(map[string]bool) + for _, str := range stringSlice { + setMapToReturn[str] = true + } + return setMapToReturn +} diff --git a/custom-endpoint/custom_endpoint_monitor.go b/custom-endpoint/custom_endpoint_monitor.go new file mode 100644 index 00000000..085eb6e5 --- /dev/null +++ b/custom-endpoint/custom_endpoint_monitor.go @@ -0,0 +1,182 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package custom_endpoint + +import ( + "context" + "log/slog" + "sync/atomic" + "time" + + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/region_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils/telemetry" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + "github.com/aws/aws-sdk-go-v2/service/rds/types" +) + +type CustomEndpointMonitor interface { + ShouldDispose() bool + Close() + HasCustomEndpointInfo() bool +} + +var customEndpointInfoCache *utils.CacheMap[*CustomEndpointInfo] = utils.NewCache[*CustomEndpointInfo]() + +const CUSTOM_ENDPOINT_INFO_EXPIRATION_NANO = time.Minute * 5 + +type CustomEndpointMonitorImpl struct { + pluginService driver_infrastructure.PluginService + customEndpointHostInfo *host_info_util.HostInfo + endpointIdentifier string + region region_util.Region + refreshRateMs time.Duration + infoChangedCounter telemetry.TelemetryCounter + rdsClient *rds.Client + stop atomic.Bool +} + +func NewCustomEndpointMonitorImpl( + pluginService driver_infrastructure.PluginService, + customEndpointHostInfo *host_info_util.HostInfo, + endpointIdentifier string, + region region_util.Region, + refreshRateMs time.Duration, + rdsClient *rds.Client) *CustomEndpointMonitorImpl { + + monitor := &CustomEndpointMonitorImpl{ + pluginService: pluginService, + customEndpointHostInfo: customEndpointHostInfo, + endpointIdentifier: endpointIdentifier, + region: region, + refreshRateMs: refreshRateMs, + rdsClient: rdsClient, + } + + go monitor.run() + + return monitor +} + +func (monitor *CustomEndpointMonitorImpl) run() { + defer func() { + slog.Debug(error_util.GetMessage("CustomEndpointMonitorImpl.stoppedMonitor", monitor.customEndpointHostInfo.Host)) + customEndpointInfoCache.Remove(monitor.getCustomEndpointInfoCacheKey()) + }() + + for !monitor.stop.Load() { + start := time.Now() + + // RDS SDK call + command := &rds.DescribeDBClusterEndpointsInput{ + DBClusterEndpointIdentifier: &monitor.endpointIdentifier, + Filters: []types.Filter{ + { + Name: aws.String("db-cluster-endpoint-type"), + Values: []string{"custom"}, + }, + }, + } + resp, err := monitor.rdsClient.DescribeDBClusterEndpoints(context.TODO(), command) + + // Error checking + if err != nil { + slog.Error(error_util.GetMessage("CustomEndpointMonitorImpl.error", err)) + continue + } else if resp == nil || resp.DBClusterEndpoints == nil { + slog.Error(error_util.GetMessage("CustomEndpointMonitorImpl.nilResponse")) + continue + } else if len(resp.DBClusterEndpoints) != 1 { + var endpointsString string + for i, endpoint := range resp.DBClusterEndpoints { + if i > 0 { + endpointsString = endpointsString + "," + } + endpointsString = endpointsString + *endpoint.Endpoint + } + slog.Warn(error_util.GetMessage("CustomEndpointMonitorImpl.unexpectedNumberOfEndpoints", + monitor.endpointIdentifier, + monitor.region, + len(resp.DBClusterEndpoints), + endpointsString)) + time.Sleep(monitor.refreshRateMs) + continue + } + + endpointInfo := NewCustomEndpointInfo(resp.DBClusterEndpoints[0]) + cachedEndpointInfo, ok := customEndpointInfoCache.Get(monitor.getCustomEndpointInfoCacheKey()) + + if ok && endpointInfo.Equals(cachedEndpointInfo) { + elapsedTime := time.Now().Sub(start) + sleepDuration := monitor.refreshRateMs - elapsedTime + if sleepDuration < 0 { + sleepDuration = 0 + } + time.Sleep(sleepDuration) + continue + } + + slog.Debug(error_util.GetMessage("CustomEndpointMonitorImpl.detectedChangeInCustomEndpointInfo", + monitor.customEndpointHostInfo.Host, endpointInfo)) + + // Custom Endpoint Info has changed. Update set of allowed/blocked hosts + var allowedAndBlockedHosts *driver_infrastructure.AllowedAndBlockedHosts + if STATIC_LIST == endpointInfo.memberListType { + allowedAndBlockedHosts = driver_infrastructure.NewAllowedAndBlockedHosts(endpointInfo.GetStaticMembers(), nil) + } else { + allowedAndBlockedHosts = driver_infrastructure.NewAllowedAndBlockedHosts(nil, endpointInfo.GetExcludedMembers()) + } + + monitor.pluginService.SetAllowedAndBlockedHosts(allowedAndBlockedHosts) + + customEndpointInfoCache.Put(monitor.customEndpointHostInfo.GetHost(), endpointInfo, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANO) + + elapsedTime := time.Now().Sub(start) + sleepDuration := monitor.refreshRateMs - elapsedTime + if sleepDuration < 0 { + sleepDuration = 0 + } + time.Sleep(sleepDuration) + } +} + +func (monitor *CustomEndpointMonitorImpl) getCustomEndpointInfoCacheKey() string { + return monitor.customEndpointHostInfo.Host +} + +func (monitor *CustomEndpointMonitorImpl) ShouldDispose() bool { + return true +} + +func (monitor *CustomEndpointMonitorImpl) HasCustomEndpointInfo() bool { + _, ok := customEndpointInfoCache.Get(monitor.customEndpointHostInfo.Host) + return ok +} + +func (monitor *CustomEndpointMonitorImpl) Close() { + slog.Debug(error_util.GetMessage("CustomEndpointMonitorImpl.stoppingMonitor", monitor.customEndpointHostInfo.Host)) + monitor.stop.Store(true) +} + +func ClearCache() { + slog.Info(error_util.GetMessage("CustomEndpointMonitorImpl.clearCache")) + customEndpointInfoCache.Clear() +} diff --git a/custom-endpoint/custom_endpoint_plugin.go b/custom-endpoint/custom_endpoint_plugin.go new file mode 100644 index 00000000..6a4bcc4b --- /dev/null +++ b/custom-endpoint/custom_endpoint_plugin.go @@ -0,0 +1,228 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package custom_endpoint + +import ( + "context" + "database/sql/driver" + "errors" + "log/slog" + "time" + + auth_helpers "github.com/aws/aws-advanced-go-wrapper/auth-helpers" + awssql "github.com/aws/aws-advanced-go-wrapper/awssql/driver" + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/plugin_helpers" + "github.com/aws/aws-advanced-go-wrapper/awssql/plugins" + "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/region_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils/telemetry" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/rds" +) + +func init() { + awssql.UsePluginFactory(driver_infrastructure.CUSTOM_ENDPOINT_PLUGIN_CODE, + NewCustomEndpointPluginFactory()) +} + +const TELEMETRY_WAIT_FOR_INFO_COUNTER = "customEndpoint.waitForInfo.counter" + +type CustomEndpointPluginFactory struct{} + +type getRdsClientFunc func(*host_info_util.HostInfo, *utils.RWMap[string, string]) (*rds.Client, error) + +func (factory CustomEndpointPluginFactory) GetInstance( + pluginService driver_infrastructure.PluginService, + props *utils.RWMap[string, string]) (driver_infrastructure.ConnectionPlugin, error) { + + return NewCustomEndpointPlugin(pluginService, getRdsClientFuncImpl, props) +} + +func getRdsClientFuncImpl(hostInfo *host_info_util.HostInfo, props *utils.RWMap[string, string]) (*rds.Client, error) { + region := property_util.GetVerifiedWrapperPropertyValue[string](props, property_util.CUSTOM_ENDPOINT_REGION_PROPERTY) + + awsCredentialsProvider, err := auth_helpers.GetAwsCredentialsProvider(*hostInfo, props.GetAllEntries()) + if err != nil { + return nil, err + } + + cfg, err := config.LoadDefaultConfig( + context.TODO(), + config.WithRegion(region), + config.WithCredentialsProvider(awsCredentialsProvider)) + if err != nil { + slog.Error("Failed to load AWS configuration", "error", err) + } + + rdsClient := rds.NewFromConfig(cfg) + return rdsClient, nil +} + +func (factory CustomEndpointPluginFactory) ClearCaches() {} + +func NewCustomEndpointPluginFactory() driver_infrastructure.ConnectionPluginFactory { + return CustomEndpointPluginFactory{} +} + +var monitorDisposalFunc utils.DisposalFunc[CustomEndpointMonitor] = func(item CustomEndpointMonitor) bool { + item.Close() + return true +} +var monitors = utils.NewSlidingExpirationCache[CustomEndpointMonitor]( + "custom-endpoint-monitor", monitorDisposalFunc) + +type CustomEndpointPlugin struct { + plugins.BaseConnectionPlugin + pluginService driver_infrastructure.PluginService + props *utils.RWMap[string, string] + shouldWaitForInfo bool + waitOnCachedInfoDurationMs int + idleMonitorExpirationMs int + waitForInfoCounter telemetry.TelemetryCounter + customEndpointHostInfo *host_info_util.HostInfo + customEndpointId string + region region_util.Region + rdsClientFunc getRdsClientFunc +} + +func NewCustomEndpointPlugin( + pluginService driver_infrastructure.PluginService, + rdsClientFunc getRdsClientFunc, + props *utils.RWMap[string, string]) (*CustomEndpointPlugin, error) { + + waitForInfoCounter, err := pluginService.GetTelemetryFactory().CreateCounter(TELEMETRY_WAIT_FOR_INFO_COUNTER) + if err != nil { + return nil, err + } + + return &CustomEndpointPlugin{ + pluginService: pluginService, + props: props, + shouldWaitForInfo: property_util.GetVerifiedWrapperPropertyValue[bool](props, property_util.WAIT_FOR_CUSTOM_ENDPOINT_INFO), + waitOnCachedInfoDurationMs: property_util.GetVerifiedWrapperPropertyValue[int](props, property_util.WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS), + idleMonitorExpirationMs: property_util.GetVerifiedWrapperPropertyValue[int](props, property_util.CUSTOM_ENDPOINT_MONITOR_IDLE_EXPIRATION_MS), + waitForInfoCounter: waitForInfoCounter, + rdsClientFunc: rdsClientFunc, + }, nil +} + +func (plugin *CustomEndpointPlugin) GetSubscribedMethods() []string { + return append([]string{ + plugin_helpers.CONNECT_METHOD, + }, utils.NETWORK_BOUND_METHODS...) +} + +func (plugin *CustomEndpointPlugin) Connect( + hostInfo *host_info_util.HostInfo, + props *utils.RWMap[string, string], + isInitialConnection bool, + connectFunc driver_infrastructure.ConnectFunc) (driver.Conn, error) { + + plugin.customEndpointHostInfo = hostInfo + plugin.customEndpointId = utils.GetRdsClusterId(hostInfo.GetHost()) + if plugin.customEndpointId == "" { + return nil, errors.New(error_util.GetMessage("CustomEndpointPlugin.errorParsingEndpointIdentifier", hostInfo.GetHost())) + } + + plugin.region = region_util.GetRegion(hostInfo.GetHost(), props, property_util.CUSTOM_ENDPOINT_REGION_PROPERTY) + if plugin.region == "" { + return nil, errors.New(error_util.GetMessage("CustomEndpointPlugin.unableToDetermineRegion", property_util.CUSTOM_ENDPOINT_REGION_PROPERTY.Name)) + } + + monitor, err := plugin.createMonitorIfAbsent(props) + if err != nil { + return nil, err + } + + if plugin.shouldWaitForInfo { + err := plugin.waitForCustomEndpointInfo(monitor) + if err != nil { + return nil, err + } + } + + return connectFunc(props) +} + +func (plugin *CustomEndpointPlugin) Execute( + _ driver.Conn, + _ string, + executeFunc driver_infrastructure.ExecuteFunc, + _ ...any) (wrappedReturnValue any, wrappedReturnValue2 any, wrappedOk bool, wrappedErr error) { + if plugin.customEndpointHostInfo == nil { + return executeFunc() + } + + monitor, err := plugin.createMonitorIfAbsent(plugin.props) + if err != nil { + return nil, nil, false, err + } + if plugin.shouldWaitForInfo { + err := plugin.waitForCustomEndpointInfo(monitor) + if err != nil { + return nil, nil, false, err + } + } + + return executeFunc() +} + +func (plugin *CustomEndpointPlugin) createMonitorIfAbsent( + props *utils.RWMap[string, string]) (CustomEndpointMonitor, error) { + refreshRateMs := time.Millisecond * time.Duration(property_util.GetRefreshRateValue(props, property_util.CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS)) + return monitors.ComputeIfAbsentWithError( + plugin.customEndpointHostInfo.Host, + func() (CustomEndpointMonitor, error) { + rdsClient, err := plugin.rdsClientFunc(plugin.customEndpointHostInfo, plugin.props) + if err != nil { + return nil, err + } + return NewCustomEndpointMonitorImpl( + plugin.pluginService, + plugin.customEndpointHostInfo, + plugin.customEndpointId, + plugin.region, + refreshRateMs, + rdsClient, + ), nil + }, 1) +} + +func (plugin *CustomEndpointPlugin) waitForCustomEndpointInfo(monitor CustomEndpointMonitor) error { + hasCustomEdnpointInfo := monitor.HasCustomEndpointInfo() + + if !hasCustomEdnpointInfo { + if plugin.waitForInfoCounter != nil { + plugin.waitForInfoCounter.Inc(plugin.pluginService.GetTelemetryContext()) + } + + waitForEndpointInfoTimeout := time.Now().Add(time.Millisecond * time.Duration(plugin.waitOnCachedInfoDurationMs)) + for !hasCustomEdnpointInfo && time.Now().Before(waitForEndpointInfoTimeout) { + time.Sleep(time.Millisecond * time.Duration(100)) + hasCustomEdnpointInfo = monitor.HasCustomEndpointInfo() + } + + if !hasCustomEdnpointInfo { + return errors.New(error_util.GetMessage("CustomEndpointPlugin.timedOutWaitingForCustomEndpointInfo")) + } + } + return nil +} diff --git a/custom-endpoint/go.mod b/custom-endpoint/go.mod new file mode 100644 index 00000000..d41c53d2 --- /dev/null +++ b/custom-endpoint/go.mod @@ -0,0 +1,33 @@ +module github.com/aws/aws-advanced-go-wrapper/custom-endpoint + +go 1.24.0 + +require ( + github.com/aws/aws-advanced-go-wrapper/auth-helpers v1.0.2 + github.com/aws/aws-advanced-go-wrapper/awssql v1.1.1 + github.com/aws/aws-sdk-go-v2/config v1.31.15 + github.com/aws/aws-sdk-go-v2/service/rds v1.101.0 +) + +require ( + github.com/aws/aws-sdk-go-v2 v1.39.4 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.19 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.11 // indirect + github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.7 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.11 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.29.8 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.38.9 // indirect + github.com/aws/smithy-go v1.23.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/nicksnyder/go-i18n/v2 v2.6.0 // indirect + golang.org/x/text v0.29.0 // indirect +) + +replace github.com/aws/aws-advanced-go-wrapper/awssql => ../awssql + +replace github.com/aws/aws-advanced-go-wrapper/auth-helpers => ../auth-helpers diff --git a/custom-endpoint/go.sum b/custom-endpoint/go.sum new file mode 100644 index 00000000..010d1195 --- /dev/null +++ b/custom-endpoint/go.sum @@ -0,0 +1,40 @@ +github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= +github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/aws/aws-sdk-go-v2 v1.39.4 h1:qTsQKcdQPHnfGYBBs+Btl8QwxJeoWcOcPcixK90mRhg= +github.com/aws/aws-sdk-go-v2 v1.39.4/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM= +github.com/aws/aws-sdk-go-v2/config v1.31.15 h1:gE3M4xuNXfC/9bG4hyowGm/35uQTi7bUKeYs5e/6uvU= +github.com/aws/aws-sdk-go-v2/config v1.31.15/go.mod h1:HvnvGJoE2I95KAIW8kkWVPJ4XhdrlvwJpV6pEzFQa8o= +github.com/aws/aws-sdk-go-v2/credentials v1.18.19 h1:Jc1zzwkSY1QbkEcLujwqRTXOdvW8ppND3jRBb/VhBQc= +github.com/aws/aws-sdk-go-v2/credentials v1.18.19/go.mod h1:DIfQ9fAk5H0pGtnqfqkbSIzky82qYnGvh06ASQXXg6A= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.11 h1:X7X4YKb+c0rkI6d4uJ5tEMxXgCZ+jZ/D6mvkno8c8Uw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.11/go.mod h1:EqM6vPZQsZHYvC4Cai35UDg/f5NCEU+vp0WfbVqVcZc= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.7 h1:ugqlp7en7XTocGQKr4j0DGm4XzdRg8WZhLT1jrVR098= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.7/go.mod h1:rbByhGJsO+49UxumRGxoFnxg9ZeYX847ldd9qtyPThU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11 h1:7AANQZkF3ihM8fbdftpjhken0TP9sBzFbV/Ze/Y4HXA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11/go.mod h1:NTF4QCGkm6fzVwncpkFQqoquQyOolcyXfbpC98urj+c= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11 h1:ShdtWUZT37LCAA4Mw2kJAJtzaszfSHFb5n25sdcv4YE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11/go.mod h1:7bUb2sSr2MZ3M/N+VyETLTQtInemHXb/Fl3s8CLzm0Y= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 h1:xtuxji5CS0JknaXoACOunXOYOQzgfTvGAc9s2QdCJA4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2/go.mod h1:zxwi0DIR0rcRcgdbl7E2MSOvxDyyXGBlScvBkARFaLQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.11 h1:GpMf3z2KJa4RnJ0ew3Hac+hRFYLZ9DDjfgXjuW+pB54= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.11/go.mod h1:6MZP3ZI4QQsgUCFTwMZA2V0sEriNQ8k2hmoHF3qjimQ= +github.com/aws/aws-sdk-go-v2/service/rds v1.101.0 h1:CWTHGWkLi+lBSt3tlFNKA8YrNG7hr1xOG6IO5XW3cpE= +github.com/aws/aws-sdk-go-v2/service/rds v1.101.0/go.mod h1:BSg3GYV7zYSk/vUsT77SlTZcYz7JmBprKslzqSuC9Nw= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.8 h1:M5nimZmugcZUO9wG7iVtROxPhiqyZX6ejS1lxlDPbTU= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.8/go.mod h1:mbef/pgKhtKRwrigPPs7SSSKZgytzP8PQ6P6JAAdqyM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.3 h1:S5GuJZpYxE0lKeMHKn+BRTz6PTFpgThyJ+5mYfux7BM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.3/go.mod h1:X4OF+BTd7HIb3L+tc4UlWHVrpgwZZIVENU15pRDVTI0= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.9 h1:Ekml5vGg6sHSZLZJQJagefnVe6PmqC2oiRkBq4F7fU0= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.9/go.mod h1:/e15V+o1zFHWdH3u7lpI3rVBcxszktIKuHKCY2/py+k= +github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M= +github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/nicksnyder/go-i18n/v2 v2.6.0 h1:C/m2NNWNiTB6SK4Ao8df5EWm3JETSTIGNXBpMJTxzxQ= +github.com/nicksnyder/go-i18n/v2 v2.6.0/go.mod h1:88sRqr0C6OPyJn0/KRNaEz1uWorjxIKP7rUUcvycecE= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From f5ad9ce3267b7b7b5228e5ee45a41fdd5f6b13b8 Mon Sep 17 00:00:00 2001 From: Aaron Chung Date: Tue, 4 Nov 2025 03:19:13 -0800 Subject: [PATCH 2/3] custom-endpoints - minor fixes and unit tests --- .test/go.mod | 29 +- .test/go.sum | 52 +-- .test/test/custom_endpoint_plugin_test.go | 303 +++++++++++++ .test/test/mock_implementations.go | 7 + .../mock_plugin_helpers.go | 411 +++++++++--------- .../mock_custom_endpoint_monitor.go | 96 ++++ .../allowed_and_blocked_hosts.go | 4 +- awssql/host_info_util/host_info_util.go | 3 + awssql/plugin_helpers/plugin_service.go | 17 +- custom-endpoint/custom_endpoint_monitor.go | 5 +- custom-endpoint/custom_endpoint_plugin.go | 40 +- 11 files changed, 705 insertions(+), 262 deletions(-) create mode 100644 .test/test/custom_endpoint_plugin_test.go create mode 100644 .test/test/mocks/custom-endpoint/mock_custom_endpoint_monitor.go diff --git a/.test/go.mod b/.test/go.mod index eeb4dda0..724cfe60 100644 --- a/.test/go.mod +++ b/.test/go.mod @@ -7,6 +7,7 @@ require ( github.com/aws/aws-advanced-go-wrapper/auth-helpers v1.0.2 github.com/aws/aws-advanced-go-wrapper/aws-secrets-manager v1.0.2 github.com/aws/aws-advanced-go-wrapper/awssql v1.1.1 + github.com/aws/aws-advanced-go-wrapper/custom-endpoint v1.0.0 github.com/aws/aws-advanced-go-wrapper/federated-auth v1.0.2 github.com/aws/aws-advanced-go-wrapper/iam v1.0.2 github.com/aws/aws-advanced-go-wrapper/mysql-driver v1.0.2 @@ -14,11 +15,11 @@ require ( github.com/aws/aws-advanced-go-wrapper/otlp v1.0.2 github.com/aws/aws-advanced-go-wrapper/pgx-driver v1.0.2 github.com/aws/aws-advanced-go-wrapper/xray v1.0.2 - github.com/aws/aws-sdk-go-v2 v1.39.0 - github.com/aws/aws-sdk-go-v2/config v1.31.8 + github.com/aws/aws-sdk-go-v2 v1.39.4 + github.com/aws/aws-sdk-go-v2/config v1.31.15 github.com/aws/aws-sdk-go-v2/service/rds v1.107.0 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.39.4 - github.com/aws/aws-sdk-go-v2/service/sts v1.38.4 + github.com/aws/aws-sdk-go-v2/service/sts v1.38.9 github.com/aws/aws-xray-sdk-go v1.8.5 github.com/go-sql-driver/mysql v1.9.3 github.com/golang/mock v1.6.0 @@ -41,17 +42,17 @@ require ( github.com/andybalholm/brotli v1.1.1 // indirect github.com/andybalholm/cascadia v1.3.3 // indirect github.com/aws/aws-sdk-go v1.55.7 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.18.12 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.7 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.19 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.11 // indirect github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.7 - github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.7 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.7 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.7 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.29.3 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.4 // indirect - github.com/aws/smithy-go v1.23.0 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.11 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.29.8 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.3 // indirect + github.com/aws/smithy-go v1.23.1 // indirect github.com/cenkalti/backoff/v5 v5.0.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-logr/logr v1.4.3 // indirect @@ -102,6 +103,8 @@ require ( replace github.com/aws/aws-advanced-go-wrapper/awssql => ../awssql +replace github.com/aws/aws-advanced-go-wrapper/custom-endpoint => ./../custom-endpoint + replace github.com/aws/aws-advanced-go-wrapper/pgx-driver => ./../pgx-driver replace github.com/aws/aws-advanced-go-wrapper/mysql-driver => ./../mysql-driver diff --git a/.test/go.sum b/.test/go.sum index 260f3ce6..54f85c13 100644 --- a/.test/go.sum +++ b/.test/go.sum @@ -14,40 +14,40 @@ github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kk github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= github.com/aws/aws-sdk-go v1.55.7 h1:UJrkFq7es5CShfBwlWAC8DA077vp8PyVbQd3lqLiztE= github.com/aws/aws-sdk-go v1.55.7/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= -github.com/aws/aws-sdk-go-v2 v1.39.0 h1:xm5WV/2L4emMRmMjHFykqiA4M/ra0DJVSWUkDyBjbg4= -github.com/aws/aws-sdk-go-v2 v1.39.0/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY= -github.com/aws/aws-sdk-go-v2/config v1.31.8 h1:kQjtOLlTU4m4A64TsRcqwNChhGCwaPBt+zCQt/oWsHU= -github.com/aws/aws-sdk-go-v2/config v1.31.8/go.mod h1:QPpc7IgljrKwH0+E6/KolCgr4WPLerURiU592AYzfSY= -github.com/aws/aws-sdk-go-v2/credentials v1.18.12 h1:zmc9e1q90wMn8wQbjryy8IwA6Q4XlaL9Bx2zIqdNNbk= -github.com/aws/aws-sdk-go-v2/credentials v1.18.12/go.mod h1:3VzdRDR5u3sSJRI4kYcOSIBbeYsgtVk7dG5R/U6qLWY= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.7 h1:Is2tPmieqGS2edBnmOJIbdvOA6Op+rRpaYR60iBAwXM= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.7/go.mod h1:F1i5V5421EGci570yABvpIXgRIBPb5JM+lSkHF6Dq5w= +github.com/aws/aws-sdk-go-v2 v1.39.4 h1:qTsQKcdQPHnfGYBBs+Btl8QwxJeoWcOcPcixK90mRhg= +github.com/aws/aws-sdk-go-v2 v1.39.4/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM= +github.com/aws/aws-sdk-go-v2/config v1.31.15 h1:gE3M4xuNXfC/9bG4hyowGm/35uQTi7bUKeYs5e/6uvU= +github.com/aws/aws-sdk-go-v2/config v1.31.15/go.mod h1:HvnvGJoE2I95KAIW8kkWVPJ4XhdrlvwJpV6pEzFQa8o= +github.com/aws/aws-sdk-go-v2/credentials v1.18.19 h1:Jc1zzwkSY1QbkEcLujwqRTXOdvW8ppND3jRBb/VhBQc= +github.com/aws/aws-sdk-go-v2/credentials v1.18.19/go.mod h1:DIfQ9fAk5H0pGtnqfqkbSIzky82qYnGvh06ASQXXg6A= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.11 h1:X7X4YKb+c0rkI6d4uJ5tEMxXgCZ+jZ/D6mvkno8c8Uw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.11/go.mod h1:EqM6vPZQsZHYvC4Cai35UDg/f5NCEU+vp0WfbVqVcZc= github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.7 h1:ugqlp7en7XTocGQKr4j0DGm4XzdRg8WZhLT1jrVR098= github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.7/go.mod h1:rbByhGJsO+49UxumRGxoFnxg9ZeYX847ldd9qtyPThU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.7 h1:UCxq0X9O3xrlENdKf1r9eRJoKz/b0AfGkpp3a7FPlhg= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.7/go.mod h1:rHRoJUNUASj5Z/0eqI4w32vKvC7atoWR0jC+IkmVH8k= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.7 h1:Y6DTZUn7ZUC4th9FMBbo8LVE+1fyq3ofw+tRwkUd3PY= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.7/go.mod h1:x3XE6vMnU9QvHN/Wrx2s44kwzV2o2g5x/siw4ZUJ9g8= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 h1:oegbebPEMA/1Jny7kvwejowCaHz1FWZAQ94WXFNCyTM= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1/go.mod h1:kemo5Myr9ac0U9JfSjMo9yHLtw+pECEHsFtJ9tqCEI8= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.7 h1:mLgc5QIgOy26qyh5bvW+nDoAppxgn3J2WV3m9ewq7+8= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.7/go.mod h1:wXb/eQnqt8mDQIQTTmcw58B5mYGxzLGZGK8PWNFZ0BA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11 h1:7AANQZkF3ihM8fbdftpjhken0TP9sBzFbV/Ze/Y4HXA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11/go.mod h1:NTF4QCGkm6fzVwncpkFQqoquQyOolcyXfbpC98urj+c= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11 h1:ShdtWUZT37LCAA4Mw2kJAJtzaszfSHFb5n25sdcv4YE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11/go.mod h1:7bUb2sSr2MZ3M/N+VyETLTQtInemHXb/Fl3s8CLzm0Y= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 h1:xtuxji5CS0JknaXoACOunXOYOQzgfTvGAc9s2QdCJA4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2/go.mod h1:zxwi0DIR0rcRcgdbl7E2MSOvxDyyXGBlScvBkARFaLQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.11 h1:GpMf3z2KJa4RnJ0ew3Hac+hRFYLZ9DDjfgXjuW+pB54= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.11/go.mod h1:6MZP3ZI4QQsgUCFTwMZA2V0sEriNQ8k2hmoHF3qjimQ= github.com/aws/aws-sdk-go-v2/service/rds v1.107.0 h1:PcG+YEp/ADK4JBq21G2I/PYlsq6wuDvUQqw2YEtECU8= github.com/aws/aws-sdk-go-v2/service/rds v1.107.0/go.mod h1:EVYMTmrAQr0LbGPy3FxHJHvPcP8x6byBwFJ9fUZKU3Q= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.39.4 h1:zWISPZre5hQb3mDMCEl6uni9rJ8K2cmvp64EXF7FXkk= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.39.4/go.mod h1:GrB/4Cn7N41psUAycqnwGDzT7qYJdUm+VnEZpyZAG4I= -github.com/aws/aws-sdk-go-v2/service/sso v1.29.3 h1:7PKX3VYsZ8LUWceVRuv0+PU+E7OtQb1lgmi5vmUE9CM= -github.com/aws/aws-sdk-go-v2/service/sso v1.29.3/go.mod h1:Ql6jE9kyyWI5JHn+61UT/Y5Z0oyVJGmgmJbZD5g4unY= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.4 h1:e0XBRn3AptQotkyBFrHAxFB8mDhAIOfsG+7KyJ0dg98= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.4/go.mod h1:XclEty74bsGBCr1s0VSaA11hQ4ZidK4viWK7rRfO88I= -github.com/aws/aws-sdk-go-v2/service/sts v1.38.4 h1:PR00NXRYgY4FWHqOGx3fC3lhVKjsp1GdloDv2ynMSd8= -github.com/aws/aws-sdk-go-v2/service/sts v1.38.4/go.mod h1:Z+Gd23v97pX9zK97+tX4ppAgqCt3Z2dIXB02CtBncK8= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.8 h1:M5nimZmugcZUO9wG7iVtROxPhiqyZX6ejS1lxlDPbTU= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.8/go.mod h1:mbef/pgKhtKRwrigPPs7SSSKZgytzP8PQ6P6JAAdqyM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.3 h1:S5GuJZpYxE0lKeMHKn+BRTz6PTFpgThyJ+5mYfux7BM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.3/go.mod h1:X4OF+BTd7HIb3L+tc4UlWHVrpgwZZIVENU15pRDVTI0= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.9 h1:Ekml5vGg6sHSZLZJQJagefnVe6PmqC2oiRkBq4F7fU0= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.9/go.mod h1:/e15V+o1zFHWdH3u7lpI3rVBcxszktIKuHKCY2/py+k= github.com/aws/aws-xray-sdk-go v1.8.5 h1:A/Gc733PHvARkjcAk+fw+0k2RT3O4VSZ+x/3YvAREfc= github.com/aws/aws-xray-sdk-go v1.8.5/go.mod h1:tDkyLXjXQ+9j49uUrFXhO9cPnpH7qp7PWkEON+KbbKs= -github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE= -github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M= +github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8= github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/.test/test/custom_endpoint_plugin_test.go b/.test/test/custom_endpoint_plugin_test.go new file mode 100644 index 00000000..52f9ed83 --- /dev/null +++ b/.test/test/custom_endpoint_plugin_test.go @@ -0,0 +1,303 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package test + +import ( + "database/sql/driver" + "errors" + "testing" + "time" + + mock_driver_infrastructure "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/awssql/driver_infrastructure" + mock_telemetry "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/awssql/util/telemetry" + mock_custom_endpoint "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/custom-endpoint" + "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" + custom_endpoint "github.com/aws/aws-advanced-go-wrapper/custom-endpoint" + "github.com/aws/aws-sdk-go-v2/service/rds" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func TestCustomEndpointPluginConnect_InvalidUrl(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockTelemetryFactory := mock_telemetry.NewMockTelemetryFactory(ctrl) + mockTelemetryCounter := mock_telemetry.NewMockTelemetryCounter(ctrl) + + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetryFactory) + mockTelemetryFactory.EXPECT().CreateCounter(custom_endpoint.TELEMETRY_WAIT_FOR_INFO_COUNTER).Return(mockTelemetryCounter, nil) + + props := utils.NewRWMap[string, string]() + rdsClientFunc := func(*host_info_util.HostInfo, *utils.RWMap[string, string]) (*rds.Client, error) { return nil, nil } + + plugin, err := custom_endpoint.NewCustomEndpointPlugin(mockPluginService, rdsClientFunc, props) + assert.NoError(t, err) + + hostInfo, err := host_info_util.NewHostInfoBuilder().SetHost("database-test-name.invalid-XYZ.us-east-2.rds.amazonaws.com").SetPort(1234).Build() + assert.NoError(t, err) + + expectedConn := &MockConn{} + mockConnFunc := func(props *utils.RWMap[string, string]) (driver.Conn, error) { + return expectedConn, nil + } + + actualConn, connErr := plugin.Connect(hostInfo, props, true, mockConnFunc) + assert.Nil(t, connErr) + assert.Equal(t, expectedConn, actualConn) +} + +func TestCustomEndpointPluginConnect_InvalidRegion(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockTelemetryFactory := mock_telemetry.NewMockTelemetryFactory(ctrl) + mockTelemetryCounter := mock_telemetry.NewMockTelemetryCounter(ctrl) + + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetryFactory) + mockTelemetryFactory.EXPECT().CreateCounter(custom_endpoint.TELEMETRY_WAIT_FOR_INFO_COUNTER).Return(mockTelemetryCounter, nil) + + props := utils.NewRWMap[string, string]() + props.Put(property_util.CUSTOM_ENDPOINT_REGION_PROPERTY.Name, "invalid-region") + rdsClientFunc := func(*host_info_util.HostInfo, *utils.RWMap[string, string]) (*rds.Client, error) { return nil, nil } + + plugin, err := custom_endpoint.NewCustomEndpointPlugin(mockPluginService, rdsClientFunc, props) + assert.NoError(t, err) + + hostInfo, err := host_info_util.NewHostInfoBuilder().SetHost("database-test-name.cluster-custom-XYZ.invalid-region.rds.amazonaws.com").SetPort(1234).Build() + assert.NoError(t, err) + + mockConnFunc := func(props *utils.RWMap[string, string]) (driver.Conn, error) { + return &MockConn{}, nil + } + + _, connErr := plugin.Connect(hostInfo, props, true, mockConnFunc) + assert.NotNil(t, connErr) + assert.Equal(t, + error_util.GetMessage("CustomEndpointPlugin.unableToDetermineRegion", property_util.CUSTOM_ENDPOINT_REGION_PROPERTY.Name), + connErr.Error()) +} + +func TestCustomEndpointPluginConnect_DontWaitForCustomEndpointInfo(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockTelemetryFactory := mock_telemetry.NewMockTelemetryFactory(ctrl) + mockTelemetryCounter := mock_telemetry.NewMockTelemetryCounter(ctrl) + + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetryFactory) + mockTelemetryFactory.EXPECT().CreateCounter(custom_endpoint.TELEMETRY_WAIT_FOR_INFO_COUNTER).Return(mockTelemetryCounter, nil) + + mockMonitor := mock_custom_endpoint.NewMockCustomEndpointMonitor(ctrl) + mockMonitor.EXPECT().HasCustomEndpointInfo().Return(true).Times(0) + mockMonitor.EXPECT().Close() + + props := utils.NewRWMap[string, string]() + props.Put(property_util.WAIT_FOR_CUSTOM_ENDPOINT_INFO.Name, "false") + rdsClientFunc := func(*host_info_util.HostInfo, *utils.RWMap[string, string]) (*rds.Client, error) { return nil, nil } + + plugin, err := custom_endpoint.NewCustomEndpointPlugin(mockPluginService, rdsClientFunc, props) + assert.NoError(t, err) + defer custom_endpoint.CustomEndpointPluginFactory{}.ClearCaches() + + hostInfo, err := host_info_util.NewHostInfoBuilder().SetHost("database-test-name.cluster-custom-XYZ.us-east-2.rds.amazonaws.com").SetPort(1234).Build() + assert.NoError(t, err) + + custom_endpoint.CUSTOM_ENDPOINT_MONITORS.Put(hostInfo.Host, mockMonitor, time.Minute*1) + + expectedConn := &MockConn{} + mockConnFunc := func(props *utils.RWMap[string, string]) (driver.Conn, error) { + return expectedConn, nil + } + + actualConn, actualConnErr := plugin.Connect(hostInfo, props, true, mockConnFunc) + + assert.Equal(t, expectedConn, actualConn) + assert.Nil(t, actualConnErr) +} + +func TestCustomEndpointPluginConnect_WaitForCustomEndpointInfo(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockTelemetryFactory := mock_telemetry.NewMockTelemetryFactory(ctrl) + mockTelemetryCounter := mock_telemetry.NewMockTelemetryCounter(ctrl) + + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetryFactory) + mockPluginService.EXPECT().GetTelemetryContext().Return(nil) + mockTelemetryFactory.EXPECT().CreateCounter(custom_endpoint.TELEMETRY_WAIT_FOR_INFO_COUNTER).Return(mockTelemetryCounter, nil) + mockTelemetryCounter.EXPECT().Inc(gomock.Any()) + + mockMonitor := mock_custom_endpoint.NewMockCustomEndpointMonitor(ctrl) + mockMonitor.EXPECT().Close() + mockedHasCustomEndpointInfoCalls0 := mockMonitor.EXPECT().HasCustomEndpointInfo().Return(false).Times(5) + mockedHasCustomEndpointInfoCalls1 := mockMonitor.EXPECT().HasCustomEndpointInfo().Return(true).Times(1) + gomock.InOrder(mockedHasCustomEndpointInfoCalls0, mockedHasCustomEndpointInfoCalls1) + + props := utils.NewRWMap[string, string]() + props.Put(property_util.WAIT_FOR_CUSTOM_ENDPOINT_INFO.Name, "true") + rdsClientFunc := func(*host_info_util.HostInfo, *utils.RWMap[string, string]) (*rds.Client, error) { return nil, nil } + + plugin, err := custom_endpoint.NewCustomEndpointPlugin(mockPluginService, rdsClientFunc, props) + assert.NoError(t, err) + defer custom_endpoint.CustomEndpointPluginFactory{}.ClearCaches() + + hostInfo, err := host_info_util.NewHostInfoBuilder().SetHost("database-test-name.cluster-custom-XYZ.us-east-2.rds.amazonaws.com").SetPort(1234).Build() + assert.NoError(t, err) + + custom_endpoint.CUSTOM_ENDPOINT_MONITORS.Put(hostInfo.Host, mockMonitor, time.Minute*1) + + expectedConn := &MockConn{} + mockConnFunc := func(props *utils.RWMap[string, string]) (driver.Conn, error) { + return expectedConn, nil + } + + actualConn, actualConnErr := plugin.Connect(hostInfo, props, true, mockConnFunc) + + assert.Equal(t, expectedConn, actualConn) + assert.Nil(t, actualConnErr) +} + +func TestCustomEndpointPluginExecute_CustomEndpointHostNotSet(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockTelemetryFactory := mock_telemetry.NewMockTelemetryFactory(ctrl) + mockTelemetryCounter := mock_telemetry.NewMockTelemetryCounter(ctrl) + + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetryFactory) + mockTelemetryFactory.EXPECT().CreateCounter(custom_endpoint.TELEMETRY_WAIT_FOR_INFO_COUNTER).Return(mockTelemetryCounter, nil) + + props := utils.NewRWMap[string, string]() + props.Put(property_util.CUSTOM_ENDPOINT_REGION_PROPERTY.Name, "invalid-region") + rdsClientFunc := func(*host_info_util.HostInfo, *utils.RWMap[string, string]) (*rds.Client, error) { return nil, nil } + + plugin, err := custom_endpoint.NewCustomEndpointPlugin(mockPluginService, rdsClientFunc, props) + assert.NoError(t, err) + + expectedResult0 := "result0" + expectedResult1 := "result1" + expectedBool := true + expectedErr := errors.New("expectedError") + mockExecuteFunc := func() (any, any, bool, error) { + return expectedResult0, expectedResult1, expectedBool, expectedErr + } + actualResult0, actualResult1, actualBool, actualErr := plugin.Execute(nil, "", mockExecuteFunc) + + assert.Equal(t, expectedResult0, actualResult0) + assert.Equal(t, expectedResult1, actualResult1) + assert.Equal(t, expectedBool, actualBool) + assert.Equal(t, expectedErr, actualErr) +} + +func TestCustomEndpointPluginExecute_DontWaitForCustomEndpointInfo(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockTelemetryFactory := mock_telemetry.NewMockTelemetryFactory(ctrl) + mockTelemetryCounter := mock_telemetry.NewMockTelemetryCounter(ctrl) + + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetryFactory) + mockTelemetryFactory.EXPECT().CreateCounter(custom_endpoint.TELEMETRY_WAIT_FOR_INFO_COUNTER).Return(mockTelemetryCounter, nil) + + mockMonitor := mock_custom_endpoint.NewMockCustomEndpointMonitor(ctrl) + mockMonitor.EXPECT().Close() + mockMonitor.EXPECT().HasCustomEndpointInfo().Return(true).Times(0) + + props := utils.NewRWMap[string, string]() + props.Put(property_util.WAIT_FOR_CUSTOM_ENDPOINT_INFO.Name, "false") + rdsClientFunc := func(*host_info_util.HostInfo, *utils.RWMap[string, string]) (*rds.Client, error) { return nil, nil } + + plugin, err := custom_endpoint.NewCustomEndpointPlugin(mockPluginService, rdsClientFunc, props) + assert.NoError(t, err) + defer custom_endpoint.CustomEndpointPluginFactory{}.ClearCaches() + + hostInfo, err := host_info_util.NewHostInfoBuilder().SetHost("database-test-name.cluster-custom-XYZ.us-east-2.rds.amazonaws.com").SetPort(1234).Build() + assert.NoError(t, err) + + custom_endpoint.CUSTOM_ENDPOINT_MONITORS.Put(hostInfo.Host, mockMonitor, time.Minute*1) + + expectedResult0 := "result0" + expectedResult1 := "result1" + expectedBool := true + expectedErr := errors.New("expectedError") + mockExecuteFunc := func() (any, any, bool, error) { + return expectedResult0, expectedResult1, expectedBool, expectedErr + } + actualResult0, actualResult1, actualBool, actualErr := plugin.Execute(nil, "", mockExecuteFunc) + + assert.Equal(t, expectedResult0, actualResult0) + assert.Equal(t, expectedResult1, actualResult1) + assert.Equal(t, expectedBool, actualBool) + assert.Equal(t, expectedErr, actualErr) +} + +func TestCustomEndpointPluginExecute_WaitForCustomEndpointInfo(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockTelemetryFactory := mock_telemetry.NewMockTelemetryFactory(ctrl) + mockTelemetryCounter := mock_telemetry.NewMockTelemetryCounter(ctrl) + + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetryFactory) + mockPluginService.EXPECT().GetTelemetryContext().Return(nil) + mockTelemetryFactory.EXPECT().CreateCounter(custom_endpoint.TELEMETRY_WAIT_FOR_INFO_COUNTER).Return(mockTelemetryCounter, nil) + mockTelemetryCounter.EXPECT().Inc(gomock.Any()) + + mockMonitor := mock_custom_endpoint.NewMockCustomEndpointMonitor(ctrl) + mockMonitor.EXPECT().Close() + mockedHasCustomEndpointInfoCalls0 := mockMonitor.EXPECT().HasCustomEndpointInfo().Return(false).Times(5) + mockedHasCustomEndpointInfoCalls1 := mockMonitor.EXPECT().HasCustomEndpointInfo().Return(true).Times(1) + gomock.InOrder(mockedHasCustomEndpointInfoCalls0, mockedHasCustomEndpointInfoCalls1) + + props := utils.NewRWMap[string, string]() + props.Put(property_util.WAIT_FOR_CUSTOM_ENDPOINT_INFO.Name, "true") + rdsClientFunc := func(*host_info_util.HostInfo, *utils.RWMap[string, string]) (*rds.Client, error) { return nil, nil } + + hostInfo, err := host_info_util.NewHostInfoBuilder().SetHost("database-test-name.cluster-custom-XYZ.us-east-2.rds.amazonaws.com").SetPort(1234).Build() + assert.NoError(t, err) + + plugin, err := custom_endpoint.NewCustomEndpointPluginWithHostInfo(mockPluginService, rdsClientFunc, props, hostInfo) + assert.NoError(t, err) + defer custom_endpoint.CustomEndpointPluginFactory{}.ClearCaches() + + custom_endpoint.CUSTOM_ENDPOINT_MONITORS.Put(hostInfo.Host, mockMonitor, time.Minute*1) + + expectedResult0 := "result0" + expectedResult1 := "result1" + expectedBool := true + expectedErr := errors.New("expectedError") + mockExecuteFunc := func() (any, any, bool, error) { + return expectedResult0, expectedResult1, expectedBool, expectedErr + } + actualResult0, actualResult1, actualBool, actualErr := plugin.Execute(nil, "", mockExecuteFunc) + + assert.Equal(t, expectedResult0, actualResult0) + assert.Equal(t, expectedResult1, actualResult1) + assert.Equal(t, expectedBool, actualBool) + assert.Equal(t, expectedErr, actualErr) +} diff --git a/.test/test/mock_implementations.go b/.test/test/mock_implementations.go index 32ef092d..838f7930 100644 --- a/.test/test/mock_implementations.go +++ b/.test/test/mock_implementations.go @@ -403,6 +403,10 @@ func (p *MockPluginService) GetHosts() []*host_info_util.HostInfo { return nil } +func (p *MockPluginService) GetAllHosts() []*host_info_util.HostInfo { + return nil +} + func (p *MockPluginService) AcceptsStrategy(_ string) bool { return false } @@ -439,6 +443,9 @@ func (p *MockPluginService) SetHostListProvider(_ driver_infrastructure.HostList func (p *MockPluginService) SetInitialConnectionHostInfo(_ *host_info_util.HostInfo) {} +func (p *MockPluginService) SetAllowedAndBlockedHosts(allowedAndBlockedHosts *driver_infrastructure.AllowedAndBlockedHosts) { +} + func (p *MockPluginService) IsStaticHostListProvider() bool { return false } diff --git a/.test/test/mocks/awssql/driver_infrastructure/mock_plugin_helpers.go b/.test/test/mocks/awssql/driver_infrastructure/mock_plugin_helpers.go index 38b3050f..1b8a8928 100644 --- a/.test/test/mocks/awssql/driver_infrastructure/mock_plugin_helpers.go +++ b/.test/test/mocks/awssql/driver_infrastructure/mock_plugin_helpers.go @@ -1,21 +1,10 @@ -/* - Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"). - You may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure (interfaces: HostListProviderService,PluginService,PluginManager,CanReleaseResources) +// Source: awssql/driver_infrastructure/plugin_helpers.go +// +// Generated by this command: +// +// mockgen -source=awssql/driver_infrastructure/plugin_helpers.go -destination=.test/test/mocks/awssql/driver_infrastructure/mock_plugin_helpers.go package=mock_driver_infrastructure +// // Package mock_driver_infrastructure is a generated GoMock package. package mock_driver_infrastructure @@ -27,7 +16,7 @@ import ( driver_infrastructure "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" host_info_util "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" - "github.com/aws/aws-advanced-go-wrapper/awssql/utils" + utils "github.com/aws/aws-advanced-go-wrapper/awssql/utils" telemetry "github.com/aws/aws-advanced-go-wrapper/awssql/utils/telemetry" gomock "github.com/golang/mock/gomock" ) @@ -36,6 +25,7 @@ import ( type MockHostListProviderService struct { ctrl *gomock.Controller recorder *MockHostListProviderServiceMockRecorder + isgomock struct{} } // MockHostListProviderServiceMockRecorder is the mock recorder for MockHostListProviderService. @@ -56,17 +46,17 @@ func (m *MockHostListProviderService) EXPECT() *MockHostListProviderServiceMockR } // CreateHostListProvider mocks base method. -func (m *MockHostListProviderService) CreateHostListProvider(arg0 *utils.RWMap[string, string]) driver_infrastructure.HostListProvider { +func (m *MockHostListProviderService) CreateHostListProvider(props *utils.RWMap[string, string]) driver_infrastructure.HostListProvider { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateHostListProvider", arg0) + ret := m.ctrl.Call(m, "CreateHostListProvider", props) ret0, _ := ret[0].(driver_infrastructure.HostListProvider) return ret0 } // CreateHostListProvider indicates an expected call of CreateHostListProvider. -func (mr *MockHostListProviderServiceMockRecorder) CreateHostListProvider(arg0 interface{}) *gomock.Call { +func (mr *MockHostListProviderServiceMockRecorder) CreateHostListProvider(props any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHostListProvider", reflect.TypeOf((*MockHostListProviderService)(nil).CreateHostListProvider), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHostListProvider", reflect.TypeOf((*MockHostListProviderService)(nil).CreateHostListProvider), props) } // GetCurrentConnection mocks base method. @@ -126,33 +116,34 @@ func (mr *MockHostListProviderServiceMockRecorder) IsStaticHostListProvider() *g } // SetHostListProvider mocks base method. -func (m *MockHostListProviderService) SetHostListProvider(arg0 driver_infrastructure.HostListProvider) { +func (m *MockHostListProviderService) SetHostListProvider(hostListProvider driver_infrastructure.HostListProvider) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetHostListProvider", arg0) + m.ctrl.Call(m, "SetHostListProvider", hostListProvider) } // SetHostListProvider indicates an expected call of SetHostListProvider. -func (mr *MockHostListProviderServiceMockRecorder) SetHostListProvider(arg0 interface{}) *gomock.Call { +func (mr *MockHostListProviderServiceMockRecorder) SetHostListProvider(hostListProvider any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHostListProvider", reflect.TypeOf((*MockHostListProviderService)(nil).SetHostListProvider), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHostListProvider", reflect.TypeOf((*MockHostListProviderService)(nil).SetHostListProvider), hostListProvider) } // SetInitialConnectionHostInfo mocks base method. -func (m *MockHostListProviderService) SetInitialConnectionHostInfo(arg0 *host_info_util.HostInfo) { +func (m *MockHostListProviderService) SetInitialConnectionHostInfo(info *host_info_util.HostInfo) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetInitialConnectionHostInfo", arg0) + m.ctrl.Call(m, "SetInitialConnectionHostInfo", info) } // SetInitialConnectionHostInfo indicates an expected call of SetInitialConnectionHostInfo. -func (mr *MockHostListProviderServiceMockRecorder) SetInitialConnectionHostInfo(arg0 interface{}) *gomock.Call { +func (mr *MockHostListProviderServiceMockRecorder) SetInitialConnectionHostInfo(info any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetInitialConnectionHostInfo", reflect.TypeOf((*MockHostListProviderService)(nil).SetInitialConnectionHostInfo), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetInitialConnectionHostInfo", reflect.TypeOf((*MockHostListProviderService)(nil).SetInitialConnectionHostInfo), info) } // MockPluginService is a mock of PluginService interface. type MockPluginService struct { ctrl *gomock.Controller recorder *MockPluginServiceMockRecorder + isgomock struct{} } // MockPluginServiceMockRecorder is the mock recorder for MockPluginService. @@ -173,117 +164,131 @@ func (m *MockPluginService) EXPECT() *MockPluginServiceMockRecorder { } // AcceptsStrategy mocks base method. -func (m *MockPluginService) AcceptsStrategy(arg0 string) bool { +func (m *MockPluginService) AcceptsStrategy(strategy string) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptsStrategy", arg0) + ret := m.ctrl.Call(m, "AcceptsStrategy", strategy) ret0, _ := ret[0].(bool) return ret0 } // AcceptsStrategy indicates an expected call of AcceptsStrategy. -func (mr *MockPluginServiceMockRecorder) AcceptsStrategy(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) AcceptsStrategy(strategy any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptsStrategy", reflect.TypeOf((*MockPluginService)(nil).AcceptsStrategy), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptsStrategy", reflect.TypeOf((*MockPluginService)(nil).AcceptsStrategy), strategy) } // Connect mocks base method. -func (m *MockPluginService) Connect(arg0 *host_info_util.HostInfo, arg1 *utils.RWMap[string, string], arg2 driver_infrastructure.ConnectionPlugin) (driver.Conn, error) { +func (m *MockPluginService) Connect(hostInfo *host_info_util.HostInfo, props *utils.RWMap[string, string], pluginToSkip driver_infrastructure.ConnectionPlugin) (driver.Conn, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connect", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "Connect", hostInfo, props, pluginToSkip) ret0, _ := ret[0].(driver.Conn) ret1, _ := ret[1].(error) return ret0, ret1 } // Connect indicates an expected call of Connect. -func (mr *MockPluginServiceMockRecorder) Connect(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) Connect(hostInfo, props, pluginToSkip any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockPluginService)(nil).Connect), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockPluginService)(nil).Connect), hostInfo, props, pluginToSkip) } // CreateHostListProvider mocks base method. -func (m *MockPluginService) CreateHostListProvider(arg0 *utils.RWMap[string, string]) driver_infrastructure.HostListProvider { +func (m *MockPluginService) CreateHostListProvider(props *utils.RWMap[string, string]) driver_infrastructure.HostListProvider { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateHostListProvider", arg0) + ret := m.ctrl.Call(m, "CreateHostListProvider", props) ret0, _ := ret[0].(driver_infrastructure.HostListProvider) return ret0 } // CreateHostListProvider indicates an expected call of CreateHostListProvider. -func (mr *MockPluginServiceMockRecorder) CreateHostListProvider(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) CreateHostListProvider(props any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHostListProvider", reflect.TypeOf((*MockPluginService)(nil).CreateHostListProvider), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHostListProvider", reflect.TypeOf((*MockPluginService)(nil).CreateHostListProvider), props) } // FillAliases mocks base method. -func (m *MockPluginService) FillAliases(arg0 driver.Conn, arg1 *host_info_util.HostInfo) { +func (m *MockPluginService) FillAliases(conn driver.Conn, hostInfo *host_info_util.HostInfo) { m.ctrl.T.Helper() - m.ctrl.Call(m, "FillAliases", arg0, arg1) + m.ctrl.Call(m, "FillAliases", conn, hostInfo) } // FillAliases indicates an expected call of FillAliases. -func (mr *MockPluginServiceMockRecorder) FillAliases(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) FillAliases(conn, hostInfo any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FillAliases", reflect.TypeOf((*MockPluginService)(nil).FillAliases), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FillAliases", reflect.TypeOf((*MockPluginService)(nil).FillAliases), conn, hostInfo) } // ForceConnect mocks base method. -func (m *MockPluginService) ForceConnect(arg0 *host_info_util.HostInfo, arg1 *utils.RWMap[string, string]) (driver.Conn, error) { +func (m *MockPluginService) ForceConnect(hostInfo *host_info_util.HostInfo, props *utils.RWMap[string, string]) (driver.Conn, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ForceConnect", arg0, arg1) + ret := m.ctrl.Call(m, "ForceConnect", hostInfo, props) ret0, _ := ret[0].(driver.Conn) ret1, _ := ret[1].(error) return ret0, ret1 } // ForceConnect indicates an expected call of ForceConnect. -func (mr *MockPluginServiceMockRecorder) ForceConnect(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) ForceConnect(hostInfo, props any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForceConnect", reflect.TypeOf((*MockPluginService)(nil).ForceConnect), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForceConnect", reflect.TypeOf((*MockPluginService)(nil).ForceConnect), hostInfo, props) } // ForceRefreshHostList mocks base method. -func (m *MockPluginService) ForceRefreshHostList(arg0 driver.Conn) error { +func (m *MockPluginService) ForceRefreshHostList(conn driver.Conn) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ForceRefreshHostList", arg0) + ret := m.ctrl.Call(m, "ForceRefreshHostList", conn) ret0, _ := ret[0].(error) return ret0 } // ForceRefreshHostList indicates an expected call of ForceRefreshHostList. -func (mr *MockPluginServiceMockRecorder) ForceRefreshHostList(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) ForceRefreshHostList(conn any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForceRefreshHostList", reflect.TypeOf((*MockPluginService)(nil).ForceRefreshHostList), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForceRefreshHostList", reflect.TypeOf((*MockPluginService)(nil).ForceRefreshHostList), conn) } // ForceRefreshHostListWithTimeout mocks base method. -func (m *MockPluginService) ForceRefreshHostListWithTimeout(arg0 bool, arg1 int) (bool, error) { +func (m *MockPluginService) ForceRefreshHostListWithTimeout(shouldVerifyWriter bool, timeoutMs int) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ForceRefreshHostListWithTimeout", arg0, arg1) + ret := m.ctrl.Call(m, "ForceRefreshHostListWithTimeout", shouldVerifyWriter, timeoutMs) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } // ForceRefreshHostListWithTimeout indicates an expected call of ForceRefreshHostListWithTimeout. -func (mr *MockPluginServiceMockRecorder) ForceRefreshHostListWithTimeout(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) ForceRefreshHostListWithTimeout(shouldVerifyWriter, timeoutMs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForceRefreshHostListWithTimeout", reflect.TypeOf((*MockPluginService)(nil).ForceRefreshHostListWithTimeout), shouldVerifyWriter, timeoutMs) +} + +// GetAllHosts mocks base method. +func (m *MockPluginService) GetAllHosts() []*host_info_util.HostInfo { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllHosts") + ret0, _ := ret[0].([]*host_info_util.HostInfo) + return ret0 +} + +// GetAllHosts indicates an expected call of GetAllHosts. +func (mr *MockPluginServiceMockRecorder) GetAllHosts() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForceRefreshHostListWithTimeout", reflect.TypeOf((*MockPluginService)(nil).ForceRefreshHostListWithTimeout), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllHosts", reflect.TypeOf((*MockPluginService)(nil).GetAllHosts)) } // GetBgStatus mocks base method. -func (m *MockPluginService) GetBgStatus(arg0 string) (driver_infrastructure.BlueGreenStatus, bool) { +func (m *MockPluginService) GetBgStatus(id string) (driver_infrastructure.BlueGreenStatus, bool) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetBgStatus", arg0) + ret := m.ctrl.Call(m, "GetBgStatus", id) ret0, _ := ret[0].(driver_infrastructure.BlueGreenStatus) ret1, _ := ret[1].(bool) return ret0, ret1 } // GetBgStatus indicates an expected call of GetBgStatus. -func (mr *MockPluginServiceMockRecorder) GetBgStatus(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) GetBgStatus(id any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBgStatus", reflect.TypeOf((*MockPluginService)(nil).GetBgStatus), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBgStatus", reflect.TypeOf((*MockPluginService)(nil).GetBgStatus), id) } // GetConnectionProvider mocks base method. @@ -372,18 +377,18 @@ func (mr *MockPluginServiceMockRecorder) GetDialect() *gomock.Call { } // GetHostInfoByStrategy mocks base method. -func (m *MockPluginService) GetHostInfoByStrategy(arg0 host_info_util.HostRole, arg1 string, arg2 []*host_info_util.HostInfo) (*host_info_util.HostInfo, error) { +func (m *MockPluginService) GetHostInfoByStrategy(role host_info_util.HostRole, strategy string, hosts []*host_info_util.HostInfo) (*host_info_util.HostInfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHostInfoByStrategy", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetHostInfoByStrategy", role, strategy, hosts) ret0, _ := ret[0].(*host_info_util.HostInfo) ret1, _ := ret[1].(error) return ret0, ret1 } // GetHostInfoByStrategy indicates an expected call of GetHostInfoByStrategy. -func (mr *MockPluginServiceMockRecorder) GetHostInfoByStrategy(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) GetHostInfoByStrategy(role, strategy, hosts any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostInfoByStrategy", reflect.TypeOf((*MockPluginService)(nil).GetHostInfoByStrategy), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostInfoByStrategy", reflect.TypeOf((*MockPluginService)(nil).GetHostInfoByStrategy), role, strategy, hosts) } // GetHostListProvider mocks base method. @@ -409,24 +414,24 @@ func (m *MockPluginService) GetHostRole(arg0 driver.Conn) host_info_util.HostRol } // GetHostRole indicates an expected call of GetHostRole. -func (mr *MockPluginServiceMockRecorder) GetHostRole(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) GetHostRole(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostRole", reflect.TypeOf((*MockPluginService)(nil).GetHostRole), arg0) } // GetHostSelectorStrategy mocks base method. -func (m *MockPluginService) GetHostSelectorStrategy(arg0 string) (driver_infrastructure.HostSelector, error) { +func (m *MockPluginService) GetHostSelectorStrategy(strategy string) (driver_infrastructure.HostSelector, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHostSelectorStrategy", arg0) + ret := m.ctrl.Call(m, "GetHostSelectorStrategy", strategy) ret0, _ := ret[0].(driver_infrastructure.HostSelector) ret1, _ := ret[1].(error) return ret0, ret1 } // GetHostSelectorStrategy indicates an expected call of GetHostSelectorStrategy. -func (mr *MockPluginServiceMockRecorder) GetHostSelectorStrategy(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) GetHostSelectorStrategy(strategy any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostSelectorStrategy", reflect.TypeOf((*MockPluginService)(nil).GetHostSelectorStrategy), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostSelectorStrategy", reflect.TypeOf((*MockPluginService)(nil).GetHostSelectorStrategy), strategy) } // GetHosts mocks base method. @@ -514,33 +519,33 @@ func (mr *MockPluginServiceMockRecorder) GetTelemetryFactory() *gomock.Call { } // GetUpdatedHostListWithTimeout mocks base method. -func (m *MockPluginService) GetUpdatedHostListWithTimeout(arg0 bool, arg1 int) ([]*host_info_util.HostInfo, error) { +func (m *MockPluginService) GetUpdatedHostListWithTimeout(shouldVerifyWriter bool, timeoutMs int) ([]*host_info_util.HostInfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUpdatedHostListWithTimeout", arg0, arg1) + ret := m.ctrl.Call(m, "GetUpdatedHostListWithTimeout", shouldVerifyWriter, timeoutMs) ret0, _ := ret[0].([]*host_info_util.HostInfo) ret1, _ := ret[1].(error) return ret0, ret1 } // GetUpdatedHostListWithTimeout indicates an expected call of GetUpdatedHostListWithTimeout. -func (mr *MockPluginServiceMockRecorder) GetUpdatedHostListWithTimeout(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) GetUpdatedHostListWithTimeout(shouldVerifyWriter, timeoutMs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUpdatedHostListWithTimeout", reflect.TypeOf((*MockPluginService)(nil).GetUpdatedHostListWithTimeout), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUpdatedHostListWithTimeout", reflect.TypeOf((*MockPluginService)(nil).GetUpdatedHostListWithTimeout), shouldVerifyWriter, timeoutMs) } // IdentifyConnection mocks base method. -func (m *MockPluginService) IdentifyConnection(arg0 driver.Conn) (*host_info_util.HostInfo, error) { +func (m *MockPluginService) IdentifyConnection(conn driver.Conn) (*host_info_util.HostInfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IdentifyConnection", arg0) + ret := m.ctrl.Call(m, "IdentifyConnection", conn) ret0, _ := ret[0].(*host_info_util.HostInfo) ret1, _ := ret[1].(error) return ret0, ret1 } // IdentifyConnection indicates an expected call of IdentifyConnection. -func (mr *MockPluginServiceMockRecorder) IdentifyConnection(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) IdentifyConnection(conn any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IdentifyConnection", reflect.TypeOf((*MockPluginService)(nil).IdentifyConnection), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IdentifyConnection", reflect.TypeOf((*MockPluginService)(nil).IdentifyConnection), conn) } // IsInTransaction mocks base method. @@ -558,45 +563,45 @@ func (mr *MockPluginServiceMockRecorder) IsInTransaction() *gomock.Call { } // IsLoginError mocks base method. -func (m *MockPluginService) IsLoginError(arg0 error) bool { +func (m *MockPluginService) IsLoginError(err error) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsLoginError", arg0) + ret := m.ctrl.Call(m, "IsLoginError", err) ret0, _ := ret[0].(bool) return ret0 } // IsLoginError indicates an expected call of IsLoginError. -func (mr *MockPluginServiceMockRecorder) IsLoginError(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) IsLoginError(err any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsLoginError", reflect.TypeOf((*MockPluginService)(nil).IsLoginError), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsLoginError", reflect.TypeOf((*MockPluginService)(nil).IsLoginError), err) } // IsNetworkError mocks base method. -func (m *MockPluginService) IsNetworkError(arg0 error) bool { +func (m *MockPluginService) IsNetworkError(err error) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsNetworkError", arg0) + ret := m.ctrl.Call(m, "IsNetworkError", err) ret0, _ := ret[0].(bool) return ret0 } // IsNetworkError indicates an expected call of IsNetworkError. -func (mr *MockPluginServiceMockRecorder) IsNetworkError(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) IsNetworkError(err any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNetworkError", reflect.TypeOf((*MockPluginService)(nil).IsNetworkError), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNetworkError", reflect.TypeOf((*MockPluginService)(nil).IsNetworkError), err) } // IsPluginInUse mocks base method. -func (m *MockPluginService) IsPluginInUse(arg0 string) bool { +func (m *MockPluginService) IsPluginInUse(pluginName string) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsPluginInUse", arg0) + ret := m.ctrl.Call(m, "IsPluginInUse", pluginName) ret0, _ := ret[0].(bool) return ret0 } // IsPluginInUse indicates an expected call of IsPluginInUse. -func (mr *MockPluginServiceMockRecorder) IsPluginInUse(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) IsPluginInUse(pluginName any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPluginInUse", reflect.TypeOf((*MockPluginService)(nil).IsPluginInUse), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPluginInUse", reflect.TypeOf((*MockPluginService)(nil).IsPluginInUse), pluginName) } // IsStaticHostListProvider mocks base method. @@ -614,17 +619,17 @@ func (mr *MockPluginServiceMockRecorder) IsStaticHostListProvider() *gomock.Call } // RefreshHostList mocks base method. -func (m *MockPluginService) RefreshHostList(arg0 driver.Conn) error { +func (m *MockPluginService) RefreshHostList(conn driver.Conn) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RefreshHostList", arg0) + ret := m.ctrl.Call(m, "RefreshHostList", conn) ret0, _ := ret[0].(error) return ret0 } // RefreshHostList indicates an expected call of RefreshHostList. -func (mr *MockPluginServiceMockRecorder) RefreshHostList(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) RefreshHostList(conn any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshHostList", reflect.TypeOf((*MockPluginService)(nil).RefreshHostList), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshHostList", reflect.TypeOf((*MockPluginService)(nil).RefreshHostList), conn) } // ResetSession mocks base method. @@ -639,42 +644,54 @@ func (mr *MockPluginServiceMockRecorder) ResetSession() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetSession", reflect.TypeOf((*MockPluginService)(nil).ResetSession)) } +// SetAllowedAndBlockedHosts mocks base method. +func (m *MockPluginService) SetAllowedAndBlockedHosts(allowedAndBlockedHosts *driver_infrastructure.AllowedAndBlockedHosts) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAllowedAndBlockedHosts", allowedAndBlockedHosts) +} + +// SetAllowedAndBlockedHosts indicates an expected call of SetAllowedAndBlockedHosts. +func (mr *MockPluginServiceMockRecorder) SetAllowedAndBlockedHosts(allowedAndBlockedHosts any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAllowedAndBlockedHosts", reflect.TypeOf((*MockPluginService)(nil).SetAllowedAndBlockedHosts), allowedAndBlockedHosts) +} + // SetAvailability mocks base method. -func (m *MockPluginService) SetAvailability(arg0 map[string]bool, arg1 host_info_util.HostAvailability) { +func (m *MockPluginService) SetAvailability(hostAliases map[string]bool, availability host_info_util.HostAvailability) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetAvailability", arg0, arg1) + m.ctrl.Call(m, "SetAvailability", hostAliases, availability) } // SetAvailability indicates an expected call of SetAvailability. -func (mr *MockPluginServiceMockRecorder) SetAvailability(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) SetAvailability(hostAliases, availability any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAvailability", reflect.TypeOf((*MockPluginService)(nil).SetAvailability), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAvailability", reflect.TypeOf((*MockPluginService)(nil).SetAvailability), hostAliases, availability) } // SetBgStatus mocks base method. -func (m *MockPluginService) SetBgStatus(arg0 driver_infrastructure.BlueGreenStatus, arg1 string) { +func (m *MockPluginService) SetBgStatus(status driver_infrastructure.BlueGreenStatus, id string) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetBgStatus", arg0, arg1) + m.ctrl.Call(m, "SetBgStatus", status, id) } // SetBgStatus indicates an expected call of SetBgStatus. -func (mr *MockPluginServiceMockRecorder) SetBgStatus(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) SetBgStatus(status, id any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBgStatus", reflect.TypeOf((*MockPluginService)(nil).SetBgStatus), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBgStatus", reflect.TypeOf((*MockPluginService)(nil).SetBgStatus), status, id) } // SetCurrentConnection mocks base method. -func (m *MockPluginService) SetCurrentConnection(arg0 driver.Conn, arg1 *host_info_util.HostInfo, arg2 driver_infrastructure.ConnectionPlugin) error { +func (m *MockPluginService) SetCurrentConnection(conn driver.Conn, hostInfo *host_info_util.HostInfo, skipNotificationForThisPlugin driver_infrastructure.ConnectionPlugin) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetCurrentConnection", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "SetCurrentConnection", conn, hostInfo, skipNotificationForThisPlugin) ret0, _ := ret[0].(error) return ret0 } // SetCurrentConnection indicates an expected call of SetCurrentConnection. -func (mr *MockPluginServiceMockRecorder) SetCurrentConnection(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) SetCurrentConnection(conn, hostInfo, skipNotificationForThisPlugin any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCurrentConnection", reflect.TypeOf((*MockPluginService)(nil).SetCurrentConnection), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCurrentConnection", reflect.TypeOf((*MockPluginService)(nil).SetCurrentConnection), conn, hostInfo, skipNotificationForThisPlugin) } // SetCurrentTx mocks base method. @@ -684,97 +701,97 @@ func (m *MockPluginService) SetCurrentTx(arg0 driver.Tx) { } // SetCurrentTx indicates an expected call of SetCurrentTx. -func (mr *MockPluginServiceMockRecorder) SetCurrentTx(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) SetCurrentTx(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCurrentTx", reflect.TypeOf((*MockPluginService)(nil).SetCurrentTx), arg0) } // SetDialect mocks base method. -func (m *MockPluginService) SetDialect(arg0 driver_infrastructure.DatabaseDialect) { +func (m *MockPluginService) SetDialect(dialect driver_infrastructure.DatabaseDialect) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetDialect", arg0) + m.ctrl.Call(m, "SetDialect", dialect) } // SetDialect indicates an expected call of SetDialect. -func (mr *MockPluginServiceMockRecorder) SetDialect(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) SetDialect(dialect any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDialect", reflect.TypeOf((*MockPluginService)(nil).SetDialect), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDialect", reflect.TypeOf((*MockPluginService)(nil).SetDialect), dialect) } // SetHostListProvider mocks base method. -func (m *MockPluginService) SetHostListProvider(arg0 driver_infrastructure.HostListProvider) { +func (m *MockPluginService) SetHostListProvider(hostListProvider driver_infrastructure.HostListProvider) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetHostListProvider", arg0) + m.ctrl.Call(m, "SetHostListProvider", hostListProvider) } // SetHostListProvider indicates an expected call of SetHostListProvider. -func (mr *MockPluginServiceMockRecorder) SetHostListProvider(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) SetHostListProvider(hostListProvider any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHostListProvider", reflect.TypeOf((*MockPluginService)(nil).SetHostListProvider), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHostListProvider", reflect.TypeOf((*MockPluginService)(nil).SetHostListProvider), hostListProvider) } // SetInTransaction mocks base method. -func (m *MockPluginService) SetInTransaction(arg0 bool) { +func (m *MockPluginService) SetInTransaction(inTransaction bool) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetInTransaction", arg0) + m.ctrl.Call(m, "SetInTransaction", inTransaction) } // SetInTransaction indicates an expected call of SetInTransaction. -func (mr *MockPluginServiceMockRecorder) SetInTransaction(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) SetInTransaction(inTransaction any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetInTransaction", reflect.TypeOf((*MockPluginService)(nil).SetInTransaction), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetInTransaction", reflect.TypeOf((*MockPluginService)(nil).SetInTransaction), inTransaction) } // SetInitialConnectionHostInfo mocks base method. -func (m *MockPluginService) SetInitialConnectionHostInfo(arg0 *host_info_util.HostInfo) { +func (m *MockPluginService) SetInitialConnectionHostInfo(info *host_info_util.HostInfo) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetInitialConnectionHostInfo", arg0) + m.ctrl.Call(m, "SetInitialConnectionHostInfo", info) } // SetInitialConnectionHostInfo indicates an expected call of SetInitialConnectionHostInfo. -func (mr *MockPluginServiceMockRecorder) SetInitialConnectionHostInfo(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) SetInitialConnectionHostInfo(info any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetInitialConnectionHostInfo", reflect.TypeOf((*MockPluginService)(nil).SetInitialConnectionHostInfo), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetInitialConnectionHostInfo", reflect.TypeOf((*MockPluginService)(nil).SetInitialConnectionHostInfo), info) } // SetTelemetryContext mocks base method. -func (m *MockPluginService) SetTelemetryContext(arg0 context.Context) { +func (m *MockPluginService) SetTelemetryContext(ctx context.Context) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetTelemetryContext", arg0) + m.ctrl.Call(m, "SetTelemetryContext", ctx) } // SetTelemetryContext indicates an expected call of SetTelemetryContext. -func (mr *MockPluginServiceMockRecorder) SetTelemetryContext(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) SetTelemetryContext(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTelemetryContext", reflect.TypeOf((*MockPluginService)(nil).SetTelemetryContext), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTelemetryContext", reflect.TypeOf((*MockPluginService)(nil).SetTelemetryContext), ctx) } // UpdateDialect mocks base method. -func (m *MockPluginService) UpdateDialect(arg0 driver.Conn) { +func (m *MockPluginService) UpdateDialect(conn driver.Conn) { m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateDialect", arg0) + m.ctrl.Call(m, "UpdateDialect", conn) } // UpdateDialect indicates an expected call of UpdateDialect. -func (mr *MockPluginServiceMockRecorder) UpdateDialect(arg0 interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) UpdateDialect(conn any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateDialect", reflect.TypeOf((*MockPluginService)(nil).UpdateDialect), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateDialect", reflect.TypeOf((*MockPluginService)(nil).UpdateDialect), conn) } // UpdateState mocks base method. -func (m *MockPluginService) UpdateState(arg0 string, arg1 ...interface{}) { +func (m *MockPluginService) UpdateState(sql string, methodArgs ...any) { m.ctrl.T.Helper() - varargs := []interface{}{arg0} - for _, a := range arg1 { + varargs := []any{sql} + for _, a := range methodArgs { varargs = append(varargs, a) } m.ctrl.Call(m, "UpdateState", varargs...) } // UpdateState indicates an expected call of UpdateState. -func (mr *MockPluginServiceMockRecorder) UpdateState(arg0 interface{}, arg1 ...interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) UpdateState(sql any, methodArgs ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0}, arg1...) + varargs := append([]any{sql}, methodArgs...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateState", reflect.TypeOf((*MockPluginService)(nil).UpdateState), varargs...) } @@ -782,6 +799,7 @@ func (mr *MockPluginServiceMockRecorder) UpdateState(arg0 interface{}, arg1 ...i type MockPluginManager struct { ctrl *gomock.Controller recorder *MockPluginManagerMockRecorder + isgomock struct{} } // MockPluginManagerMockRecorder is the mock recorder for MockPluginManager. @@ -802,69 +820,69 @@ func (m *MockPluginManager) EXPECT() *MockPluginManagerMockRecorder { } // AcceptsStrategy mocks base method. -func (m *MockPluginManager) AcceptsStrategy(arg0 string) bool { +func (m *MockPluginManager) AcceptsStrategy(strategy string) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcceptsStrategy", arg0) + ret := m.ctrl.Call(m, "AcceptsStrategy", strategy) ret0, _ := ret[0].(bool) return ret0 } // AcceptsStrategy indicates an expected call of AcceptsStrategy. -func (mr *MockPluginManagerMockRecorder) AcceptsStrategy(arg0 interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) AcceptsStrategy(strategy any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptsStrategy", reflect.TypeOf((*MockPluginManager)(nil).AcceptsStrategy), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptsStrategy", reflect.TypeOf((*MockPluginManager)(nil).AcceptsStrategy), strategy) } // Connect mocks base method. -func (m *MockPluginManager) Connect(arg0 *host_info_util.HostInfo, arg1 *utils.RWMap[string, string], arg2 bool, arg3 driver_infrastructure.ConnectionPlugin) (driver.Conn, error) { +func (m *MockPluginManager) Connect(hostInfo *host_info_util.HostInfo, props *utils.RWMap[string, string], isInitialConnection bool, pluginToSkip driver_infrastructure.ConnectionPlugin) (driver.Conn, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connect", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "Connect", hostInfo, props, isInitialConnection, pluginToSkip) ret0, _ := ret[0].(driver.Conn) ret1, _ := ret[1].(error) return ret0, ret1 } // Connect indicates an expected call of Connect. -func (mr *MockPluginManagerMockRecorder) Connect(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) Connect(hostInfo, props, isInitialConnection, pluginToSkip any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockPluginManager)(nil).Connect), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockPluginManager)(nil).Connect), hostInfo, props, isInitialConnection, pluginToSkip) } // Execute mocks base method. -func (m *MockPluginManager) Execute(arg0 driver.Conn, arg1 string, arg2 driver_infrastructure.ExecuteFunc, arg3 ...interface{}) (interface{}, interface{}, bool, error) { +func (m *MockPluginManager) Execute(connInvokedOn driver.Conn, name string, methodFunc driver_infrastructure.ExecuteFunc, methodArgs ...any) (any, any, bool, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { + varargs := []any{connInvokedOn, name, methodFunc} + for _, a := range methodArgs { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Execute", varargs...) - ret0, _ := ret[0].(interface{}) - ret1, _ := ret[1].(interface{}) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(any) ret2, _ := ret[2].(bool) ret3, _ := ret[3].(error) return ret0, ret1, ret2, ret3 } // Execute indicates an expected call of Execute. -func (mr *MockPluginManagerMockRecorder) Execute(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) Execute(connInvokedOn, name, methodFunc any, methodArgs ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + varargs := append([]any{connInvokedOn, name, methodFunc}, methodArgs...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Execute", reflect.TypeOf((*MockPluginManager)(nil).Execute), varargs...) } // ForceConnect mocks base method. -func (m *MockPluginManager) ForceConnect(arg0 *host_info_util.HostInfo, arg1 *utils.RWMap[string, string], arg2 bool) (driver.Conn, error) { +func (m *MockPluginManager) ForceConnect(hostInfo *host_info_util.HostInfo, props *utils.RWMap[string, string], isInitialConnection bool) (driver.Conn, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ForceConnect", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "ForceConnect", hostInfo, props, isInitialConnection) ret0, _ := ret[0].(driver.Conn) ret1, _ := ret[1].(error) return ret0, ret1 } // ForceConnect indicates an expected call of ForceConnect. -func (mr *MockPluginManagerMockRecorder) ForceConnect(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) ForceConnect(hostInfo, props, isInitialConnection any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForceConnect", reflect.TypeOf((*MockPluginManager)(nil).ForceConnect), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForceConnect", reflect.TypeOf((*MockPluginManager)(nil).ForceConnect), hostInfo, props, isInitialConnection) } // GetConnectionProviderManager mocks base method. @@ -910,33 +928,33 @@ func (mr *MockPluginManagerMockRecorder) GetEffectiveConnectionProvider() *gomoc } // GetHostInfoByStrategy mocks base method. -func (m *MockPluginManager) GetHostInfoByStrategy(arg0 host_info_util.HostRole, arg1 string, arg2 []*host_info_util.HostInfo) (*host_info_util.HostInfo, error) { +func (m *MockPluginManager) GetHostInfoByStrategy(role host_info_util.HostRole, strategy string, hosts []*host_info_util.HostInfo) (*host_info_util.HostInfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHostInfoByStrategy", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetHostInfoByStrategy", role, strategy, hosts) ret0, _ := ret[0].(*host_info_util.HostInfo) ret1, _ := ret[1].(error) return ret0, ret1 } // GetHostInfoByStrategy indicates an expected call of GetHostInfoByStrategy. -func (mr *MockPluginManagerMockRecorder) GetHostInfoByStrategy(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) GetHostInfoByStrategy(role, strategy, hosts any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostInfoByStrategy", reflect.TypeOf((*MockPluginManager)(nil).GetHostInfoByStrategy), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostInfoByStrategy", reflect.TypeOf((*MockPluginManager)(nil).GetHostInfoByStrategy), role, strategy, hosts) } // GetHostSelectorStrategy mocks base method. -func (m *MockPluginManager) GetHostSelectorStrategy(arg0 string) (driver_infrastructure.HostSelector, error) { +func (m *MockPluginManager) GetHostSelectorStrategy(strategy string) (driver_infrastructure.HostSelector, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHostSelectorStrategy", arg0) + ret := m.ctrl.Call(m, "GetHostSelectorStrategy", strategy) ret0, _ := ret[0].(driver_infrastructure.HostSelector) ret1, _ := ret[1].(error) return ret0, ret1 } // GetHostSelectorStrategy indicates an expected call of GetHostSelectorStrategy. -func (mr *MockPluginManagerMockRecorder) GetHostSelectorStrategy(arg0 interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) GetHostSelectorStrategy(strategy any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostSelectorStrategy", reflect.TypeOf((*MockPluginManager)(nil).GetHostSelectorStrategy), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostSelectorStrategy", reflect.TypeOf((*MockPluginManager)(nil).GetHostSelectorStrategy), strategy) } // GetTelemetryContext mocks base method. @@ -968,85 +986,85 @@ func (mr *MockPluginManagerMockRecorder) GetTelemetryFactory() *gomock.Call { } // Init mocks base method. -func (m *MockPluginManager) Init(arg0 driver_infrastructure.PluginService, arg1 []driver_infrastructure.ConnectionPlugin) error { +func (m *MockPluginManager) Init(pluginService driver_infrastructure.PluginService, plugins []driver_infrastructure.ConnectionPlugin) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Init", arg0, arg1) + ret := m.ctrl.Call(m, "Init", pluginService, plugins) ret0, _ := ret[0].(error) return ret0 } // Init indicates an expected call of Init. -func (mr *MockPluginManagerMockRecorder) Init(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) Init(pluginService, plugins any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockPluginManager)(nil).Init), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockPluginManager)(nil).Init), pluginService, plugins) } // InitHostProvider mocks base method. -func (m *MockPluginManager) InitHostProvider(arg0 *utils.RWMap[string, string], arg1 driver_infrastructure.HostListProviderService) error { +func (m *MockPluginManager) InitHostProvider(props *utils.RWMap[string, string], hostListProviderService driver_infrastructure.HostListProviderService) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InitHostProvider", arg0, arg1) + ret := m.ctrl.Call(m, "InitHostProvider", props, hostListProviderService) ret0, _ := ret[0].(error) return ret0 } // InitHostProvider indicates an expected call of InitHostProvider. -func (mr *MockPluginManagerMockRecorder) InitHostProvider(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) InitHostProvider(props, hostListProviderService any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitHostProvider", reflect.TypeOf((*MockPluginManager)(nil).InitHostProvider), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitHostProvider", reflect.TypeOf((*MockPluginManager)(nil).InitHostProvider), props, hostListProviderService) } // IsPluginInUse mocks base method. -func (m *MockPluginManager) IsPluginInUse(arg0 string) bool { +func (m *MockPluginManager) IsPluginInUse(pluginName string) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsPluginInUse", arg0) + ret := m.ctrl.Call(m, "IsPluginInUse", pluginName) ret0, _ := ret[0].(bool) return ret0 } // IsPluginInUse indicates an expected call of IsPluginInUse. -func (mr *MockPluginManagerMockRecorder) IsPluginInUse(arg0 interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) IsPluginInUse(pluginName any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPluginInUse", reflect.TypeOf((*MockPluginManager)(nil).IsPluginInUse), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPluginInUse", reflect.TypeOf((*MockPluginManager)(nil).IsPluginInUse), pluginName) } // NotifyConnectionChanged mocks base method. -func (m *MockPluginManager) NotifyConnectionChanged(arg0 map[driver_infrastructure.HostChangeOptions]bool, arg1 driver_infrastructure.ConnectionPlugin) map[driver_infrastructure.OldConnectionSuggestedAction]bool { +func (m *MockPluginManager) NotifyConnectionChanged(changes map[driver_infrastructure.HostChangeOptions]bool, skipNotificationForThisPlugin driver_infrastructure.ConnectionPlugin) map[driver_infrastructure.OldConnectionSuggestedAction]bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotifyConnectionChanged", arg0, arg1) + ret := m.ctrl.Call(m, "NotifyConnectionChanged", changes, skipNotificationForThisPlugin) ret0, _ := ret[0].(map[driver_infrastructure.OldConnectionSuggestedAction]bool) return ret0 } // NotifyConnectionChanged indicates an expected call of NotifyConnectionChanged. -func (mr *MockPluginManagerMockRecorder) NotifyConnectionChanged(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) NotifyConnectionChanged(changes, skipNotificationForThisPlugin any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotifyConnectionChanged", reflect.TypeOf((*MockPluginManager)(nil).NotifyConnectionChanged), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotifyConnectionChanged", reflect.TypeOf((*MockPluginManager)(nil).NotifyConnectionChanged), changes, skipNotificationForThisPlugin) } // NotifyHostListChanged mocks base method. -func (m *MockPluginManager) NotifyHostListChanged(arg0 map[string]map[driver_infrastructure.HostChangeOptions]bool) { +func (m *MockPluginManager) NotifyHostListChanged(changes map[string]map[driver_infrastructure.HostChangeOptions]bool) { m.ctrl.T.Helper() - m.ctrl.Call(m, "NotifyHostListChanged", arg0) + m.ctrl.Call(m, "NotifyHostListChanged", changes) } // NotifyHostListChanged indicates an expected call of NotifyHostListChanged. -func (mr *MockPluginManagerMockRecorder) NotifyHostListChanged(arg0 interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) NotifyHostListChanged(changes any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotifyHostListChanged", reflect.TypeOf((*MockPluginManager)(nil).NotifyHostListChanged), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotifyHostListChanged", reflect.TypeOf((*MockPluginManager)(nil).NotifyHostListChanged), changes) } // NotifySubscribedPlugins mocks base method. -func (m *MockPluginManager) NotifySubscribedPlugins(arg0 string, arg1 driver_infrastructure.PluginExecFunc, arg2 driver_infrastructure.ConnectionPlugin) error { +func (m *MockPluginManager) NotifySubscribedPlugins(methodName string, pluginFunc driver_infrastructure.PluginExecFunc, skipNotificationForThisPlugin driver_infrastructure.ConnectionPlugin) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotifySubscribedPlugins", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "NotifySubscribedPlugins", methodName, pluginFunc, skipNotificationForThisPlugin) ret0, _ := ret[0].(error) return ret0 } // NotifySubscribedPlugins indicates an expected call of NotifySubscribedPlugins. -func (mr *MockPluginManagerMockRecorder) NotifySubscribedPlugins(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) NotifySubscribedPlugins(methodName, pluginFunc, skipNotificationForThisPlugin any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotifySubscribedPlugins", reflect.TypeOf((*MockPluginManager)(nil).NotifySubscribedPlugins), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotifySubscribedPlugins", reflect.TypeOf((*MockPluginManager)(nil).NotifySubscribedPlugins), methodName, pluginFunc, skipNotificationForThisPlugin) } // ReleaseResources mocks base method. @@ -1062,35 +1080,36 @@ func (mr *MockPluginManagerMockRecorder) ReleaseResources() *gomock.Call { } // SetTelemetryContext mocks base method. -func (m *MockPluginManager) SetTelemetryContext(arg0 context.Context) { +func (m *MockPluginManager) SetTelemetryContext(ctx context.Context) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetTelemetryContext", arg0) + m.ctrl.Call(m, "SetTelemetryContext", ctx) } // SetTelemetryContext indicates an expected call of SetTelemetryContext. -func (mr *MockPluginManagerMockRecorder) SetTelemetryContext(arg0 interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) SetTelemetryContext(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTelemetryContext", reflect.TypeOf((*MockPluginManager)(nil).SetTelemetryContext), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTelemetryContext", reflect.TypeOf((*MockPluginManager)(nil).SetTelemetryContext), ctx) } // UnwrapPlugin mocks base method. -func (m *MockPluginManager) UnwrapPlugin(arg0 string) driver_infrastructure.ConnectionPlugin { +func (m *MockPluginManager) UnwrapPlugin(pluginCode string) driver_infrastructure.ConnectionPlugin { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UnwrapPlugin", arg0) + ret := m.ctrl.Call(m, "UnwrapPlugin", pluginCode) ret0, _ := ret[0].(driver_infrastructure.ConnectionPlugin) return ret0 } // UnwrapPlugin indicates an expected call of UnwrapPlugin. -func (mr *MockPluginManagerMockRecorder) UnwrapPlugin(arg0 interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) UnwrapPlugin(pluginCode any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnwrapPlugin", reflect.TypeOf((*MockPluginManager)(nil).UnwrapPlugin), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnwrapPlugin", reflect.TypeOf((*MockPluginManager)(nil).UnwrapPlugin), pluginCode) } // MockCanReleaseResources is a mock of CanReleaseResources interface. type MockCanReleaseResources struct { ctrl *gomock.Controller recorder *MockCanReleaseResourcesMockRecorder + isgomock struct{} } // MockCanReleaseResourcesMockRecorder is the mock recorder for MockCanReleaseResources. diff --git a/.test/test/mocks/custom-endpoint/mock_custom_endpoint_monitor.go b/.test/test/mocks/custom-endpoint/mock_custom_endpoint_monitor.go new file mode 100644 index 00000000..cfff14de --- /dev/null +++ b/.test/test/mocks/custom-endpoint/mock_custom_endpoint_monitor.go @@ -0,0 +1,96 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +// Code generated by MockGen. DO NOT EDIT. +// Source: custom-endpoint/custom_endpoint_monitor.go +// +// Generated by this command: +// +// mockgen -source=custom-endpoint/custom_endpoint_monitor.go -destination=.test/test/mocks/custom-endpoint/mock_custom_endpoint_monitor.go package=mock_custom_endpoint +// + +// Package mock_custom_endpoint is a generated GoMock package. +package mock_custom_endpoint + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockCustomEndpointMonitor is a mock of CustomEndpointMonitor interface. +type MockCustomEndpointMonitor struct { + ctrl *gomock.Controller + recorder *MockCustomEndpointMonitorMockRecorder + isgomock struct{} +} + +// MockCustomEndpointMonitorMockRecorder is the mock recorder for MockCustomEndpointMonitor. +type MockCustomEndpointMonitorMockRecorder struct { + mock *MockCustomEndpointMonitor +} + +// NewMockCustomEndpointMonitor creates a new mock instance. +func NewMockCustomEndpointMonitor(ctrl *gomock.Controller) *MockCustomEndpointMonitor { + mock := &MockCustomEndpointMonitor{ctrl: ctrl} + mock.recorder = &MockCustomEndpointMonitorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCustomEndpointMonitor) EXPECT() *MockCustomEndpointMonitorMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockCustomEndpointMonitor) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. +func (mr *MockCustomEndpointMonitorMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCustomEndpointMonitor)(nil).Close)) +} + +// HasCustomEndpointInfo mocks base method. +func (m *MockCustomEndpointMonitor) HasCustomEndpointInfo() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasCustomEndpointInfo") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasCustomEndpointInfo indicates an expected call of HasCustomEndpointInfo. +func (mr *MockCustomEndpointMonitorMockRecorder) HasCustomEndpointInfo() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasCustomEndpointInfo", reflect.TypeOf((*MockCustomEndpointMonitor)(nil).HasCustomEndpointInfo)) +} + +// ShouldDispose mocks base method. +func (m *MockCustomEndpointMonitor) ShouldDispose() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ShouldDispose") + ret0, _ := ret[0].(bool) + return ret0 +} + +// ShouldDispose indicates an expected call of ShouldDispose. +func (mr *MockCustomEndpointMonitorMockRecorder) ShouldDispose() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldDispose", reflect.TypeOf((*MockCustomEndpointMonitor)(nil).ShouldDispose)) +} diff --git a/awssql/driver_infrastructure/allowed_and_blocked_hosts.go b/awssql/driver_infrastructure/allowed_and_blocked_hosts.go index 81c80536..29ab38b1 100644 --- a/awssql/driver_infrastructure/allowed_and_blocked_hosts.go +++ b/awssql/driver_infrastructure/allowed_and_blocked_hosts.go @@ -27,10 +27,10 @@ func NewAllowedAndBlockedHosts( var allowedHostIdsToSet map[string]bool var blockedHostIdsToSet map[string]bool - if allowedHostIds != nil && len(allowedHostIds) > 0 { + if len(allowedHostIds) > 0 { allowedHostIdsToSet = allowedHostIds } - if blockedHostIds != nil && len(blockedHostIds) > 0 { + if len(blockedHostIds) > 0 { blockedHostIdsToSet = blockedHostIds } return &AllowedAndBlockedHosts{ diff --git a/awssql/host_info_util/host_info_util.go b/awssql/host_info_util/host_info_util.go index f5a45003..e0a13805 100644 --- a/awssql/host_info_util/host_info_util.go +++ b/awssql/host_info_util/host_info_util.go @@ -87,6 +87,9 @@ func HaveNoHostsInCommon(hosts1 []*HostInfo, hosts2 []*HostInfo) bool { } func IsHostInList(host *HostInfo, hosts []*HostInfo) bool { + if host == nil { + return false + } if len(hosts) < 1 { return false } diff --git a/awssql/plugin_helpers/plugin_service.go b/awssql/plugin_helpers/plugin_service.go index a4299304..5b5dc5a2 100644 --- a/awssql/plugin_helpers/plugin_service.go +++ b/awssql/plugin_helpers/plugin_service.go @@ -250,17 +250,12 @@ func (p *PluginServiceImpl) GetCurrentHostInfo() (*host_info_util.HostInfo, erro p.currentHostInfo = host_info_util.GetWriter(p.AllHosts) allowedHosts := p.GetHosts() - if !host_info_util.IsHostInList(p.currentHostInfo, allowedHosts) { - if p.currentHostInfo == nil { - return nil, error_util.NewGenericAwsWrapperError( - error_util.GetMessage("PluginServiceImpl.currentHostNotAllowed", p.currentHostInfo.GetHostAndPort(), utils.LogTopology(allowedHosts, ""))) - } else { - return nil, error_util.NewGenericAwsWrapperError( - error_util.GetMessage("PluginServiceImpl.currentHostNotAllowed", "", utils.LogTopology(allowedHosts, ""))) - } + if p.currentHostInfo != nil && !host_info_util.IsHostInList(p.currentHostInfo, allowedHosts) { + return nil, error_util.NewGenericAwsWrapperError( + error_util.GetMessage("PluginServiceImpl.currentHostNotAllowed", p.currentHostInfo.GetHostAndPort(), utils.LogTopology(allowedHosts, ""))) } - if p.currentHostInfo.IsNil() { + if p.currentHostInfo == nil || p.currentHostInfo.IsNil() { p.currentHostInfo = p.AllHosts[0] } } @@ -287,14 +282,14 @@ func (p *PluginServiceImpl) GetHosts() []*host_info_util.HostInfo { allowedHosts := p.allowedAndBlockedHosts.Load().GetAllowedHostIds() blockedHosts := p.allowedAndBlockedHosts.Load().GetBlockedHostIds() - if allowedHosts != nil && len(allowedHosts) > 0 { + if len(allowedHosts) > 0 { hosts = utils.FilterSlice(hosts, func(item *host_info_util.HostInfo) bool { value, ok := allowedHosts[item.HostId] return ok && value }) } - if blockedHosts != nil && len(blockedHosts) > 0 { + if len(blockedHosts) > 0 { hosts = utils.FilterSlice(hosts, func(item *host_info_util.HostInfo) bool { value, ok := blockedHosts[item.HostId] return !ok || !value diff --git a/custom-endpoint/custom_endpoint_monitor.go b/custom-endpoint/custom_endpoint_monitor.go index 085eb6e5..8bbee598 100644 --- a/custom-endpoint/custom_endpoint_monitor.go +++ b/custom-endpoint/custom_endpoint_monitor.go @@ -61,7 +61,6 @@ func NewCustomEndpointMonitorImpl( region region_util.Region, refreshRateMs time.Duration, rdsClient *rds.Client) *CustomEndpointMonitorImpl { - monitor := &CustomEndpointMonitorImpl{ pluginService: pluginService, customEndpointHostInfo: customEndpointHostInfo, @@ -125,7 +124,7 @@ func (monitor *CustomEndpointMonitorImpl) run() { cachedEndpointInfo, ok := customEndpointInfoCache.Get(monitor.getCustomEndpointInfoCacheKey()) if ok && endpointInfo.Equals(cachedEndpointInfo) { - elapsedTime := time.Now().Sub(start) + elapsedTime := time.Since(start) sleepDuration := monitor.refreshRateMs - elapsedTime if sleepDuration < 0 { sleepDuration = 0 @@ -149,7 +148,7 @@ func (monitor *CustomEndpointMonitorImpl) run() { customEndpointInfoCache.Put(monitor.customEndpointHostInfo.GetHost(), endpointInfo, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANO) - elapsedTime := time.Now().Sub(start) + elapsedTime := time.Since(start) sleepDuration := monitor.refreshRateMs - elapsedTime if sleepDuration < 0 { sleepDuration = 0 diff --git a/custom-endpoint/custom_endpoint_plugin.go b/custom-endpoint/custom_endpoint_plugin.go index 6a4bcc4b..49c7eaeb 100644 --- a/custom-endpoint/custom_endpoint_plugin.go +++ b/custom-endpoint/custom_endpoint_plugin.go @@ -52,7 +52,6 @@ type getRdsClientFunc func(*host_info_util.HostInfo, *utils.RWMap[string, string func (factory CustomEndpointPluginFactory) GetInstance( pluginService driver_infrastructure.PluginService, props *utils.RWMap[string, string]) (driver_infrastructure.ConnectionPlugin, error) { - return NewCustomEndpointPlugin(pluginService, getRdsClientFuncImpl, props) } @@ -76,18 +75,15 @@ func getRdsClientFuncImpl(hostInfo *host_info_util.HostInfo, props *utils.RWMap[ return rdsClient, nil } -func (factory CustomEndpointPluginFactory) ClearCaches() {} +func (factory CustomEndpointPluginFactory) ClearCaches() { + CUSTOM_ENDPOINT_MONITORS.CleanUp() +} func NewCustomEndpointPluginFactory() driver_infrastructure.ConnectionPluginFactory { return CustomEndpointPluginFactory{} } -var monitorDisposalFunc utils.DisposalFunc[CustomEndpointMonitor] = func(item CustomEndpointMonitor) bool { - item.Close() - return true -} -var monitors = utils.NewSlidingExpirationCache[CustomEndpointMonitor]( - "custom-endpoint-monitor", monitorDisposalFunc) +var CUSTOM_ENDPOINT_MONITORS *utils.SlidingExpirationCache[CustomEndpointMonitor] type CustomEndpointPlugin struct { plugins.BaseConnectionPlugin @@ -107,12 +103,20 @@ func NewCustomEndpointPlugin( pluginService driver_infrastructure.PluginService, rdsClientFunc getRdsClientFunc, props *utils.RWMap[string, string]) (*CustomEndpointPlugin, error) { - waitForInfoCounter, err := pluginService.GetTelemetryFactory().CreateCounter(TELEMETRY_WAIT_FOR_INFO_COUNTER) if err != nil { return nil, err } + if CUSTOM_ENDPOINT_MONITORS == nil { + CUSTOM_ENDPOINT_MONITORS = utils.NewSlidingExpirationCache( + "custom-endpoint-monitor", + func(item CustomEndpointMonitor) bool { + item.Close() + return true + }) + } + return &CustomEndpointPlugin{ pluginService: pluginService, props: props, @@ -124,6 +128,17 @@ func NewCustomEndpointPlugin( }, nil } +// NOTE: This method is for testing purposes. +func NewCustomEndpointPluginWithHostInfo( + pluginService driver_infrastructure.PluginService, + rdsClientFunc getRdsClientFunc, + props *utils.RWMap[string, string], + customEndpointHostInfo *host_info_util.HostInfo) (*CustomEndpointPlugin, error) { + plugin, err := NewCustomEndpointPlugin(pluginService, rdsClientFunc, props) + plugin.customEndpointHostInfo = customEndpointHostInfo + return plugin, err +} + func (plugin *CustomEndpointPlugin) GetSubscribedMethods() []string { return append([]string{ plugin_helpers.CONNECT_METHOD, @@ -135,6 +150,9 @@ func (plugin *CustomEndpointPlugin) Connect( props *utils.RWMap[string, string], isInitialConnection bool, connectFunc driver_infrastructure.ConnectFunc) (driver.Conn, error) { + if !utils.IsRdsCustomClusterDns(hostInfo.GetHost()) { + return connectFunc(props) + } plugin.customEndpointHostInfo = hostInfo plugin.customEndpointId = utils.GetRdsClusterId(hostInfo.GetHost()) @@ -188,7 +206,7 @@ func (plugin *CustomEndpointPlugin) Execute( func (plugin *CustomEndpointPlugin) createMonitorIfAbsent( props *utils.RWMap[string, string]) (CustomEndpointMonitor, error) { refreshRateMs := time.Millisecond * time.Duration(property_util.GetRefreshRateValue(props, property_util.CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS)) - return monitors.ComputeIfAbsentWithError( + return CUSTOM_ENDPOINT_MONITORS.ComputeIfAbsentWithError( plugin.customEndpointHostInfo.Host, func() (CustomEndpointMonitor, error) { rdsClient, err := plugin.rdsClientFunc(plugin.customEndpointHostInfo, plugin.props) @@ -203,7 +221,7 @@ func (plugin *CustomEndpointPlugin) createMonitorIfAbsent( refreshRateMs, rdsClient, ), nil - }, 1) + }, time.Duration(plugin.idleMonitorExpirationMs)*time.Millisecond) } func (plugin *CustomEndpointPlugin) waitForCustomEndpointInfo(monitor CustomEndpointMonitor) error { From 3b210557576ccbe003d8584ba96007a515c6d920 Mon Sep 17 00:00:00 2001 From: Aaron Chung Date: Fri, 7 Nov 2025 15:25:52 -0800 Subject: [PATCH 3/3] custom-endpoints - address PR comments --- .../mock_plugin_helpers.go | 16 ++++++++++++++++ awssql/plugin_helpers/plugin_service.go | 4 ++-- custom-endpoint/custom_endpoint_monitor.go | 5 ++++- custom-endpoint/custom_endpoint_plugin.go | 19 +++++++++++++------ 4 files changed, 35 insertions(+), 9 deletions(-) diff --git a/.test/test/mocks/awssql/driver_infrastructure/mock_plugin_helpers.go b/.test/test/mocks/awssql/driver_infrastructure/mock_plugin_helpers.go index 1b8a8928..06073efe 100644 --- a/.test/test/mocks/awssql/driver_infrastructure/mock_plugin_helpers.go +++ b/.test/test/mocks/awssql/driver_infrastructure/mock_plugin_helpers.go @@ -1,3 +1,19 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + // Code generated by MockGen. DO NOT EDIT. // Source: awssql/driver_infrastructure/plugin_helpers.go // diff --git a/awssql/plugin_helpers/plugin_service.go b/awssql/plugin_helpers/plugin_service.go index 5b5dc5a2..a1dd7bd5 100644 --- a/awssql/plugin_helpers/plugin_service.go +++ b/awssql/plugin_helpers/plugin_service.go @@ -255,7 +255,7 @@ func (p *PluginServiceImpl) GetCurrentHostInfo() (*host_info_util.HostInfo, erro error_util.GetMessage("PluginServiceImpl.currentHostNotAllowed", p.currentHostInfo.GetHostAndPort(), utils.LogTopology(allowedHosts, ""))) } - if p.currentHostInfo == nil || p.currentHostInfo.IsNil() { + if p.currentHostInfo.IsNil() { p.currentHostInfo = p.AllHosts[0] } } @@ -267,7 +267,7 @@ func (p *PluginServiceImpl) GetCurrentHostInfo() (*host_info_util.HostInfo, erro return p.currentHostInfo, nil } -// TODO: transfer some uses of #GetHost to #GetAllHosts +// TODO: transfer some uses of #GetHost to #GetAllHosts. func (p *PluginServiceImpl) GetAllHosts() []*host_info_util.HostInfo { return p.AllHosts } diff --git a/custom-endpoint/custom_endpoint_monitor.go b/custom-endpoint/custom_endpoint_monitor.go index 8bbee598..66c4a7ea 100644 --- a/custom-endpoint/custom_endpoint_monitor.go +++ b/custom-endpoint/custom_endpoint_monitor.go @@ -60,6 +60,7 @@ func NewCustomEndpointMonitorImpl( endpointIdentifier string, region region_util.Region, refreshRateMs time.Duration, + infoChangedCounter telemetry.TelemetryCounter, rdsClient *rds.Client) *CustomEndpointMonitorImpl { monitor := &CustomEndpointMonitorImpl{ pluginService: pluginService, @@ -67,6 +68,7 @@ func NewCustomEndpointMonitorImpl( endpointIdentifier: endpointIdentifier, region: region, refreshRateMs: refreshRateMs, + infoChangedCounter: infoChangedCounter, rdsClient: rdsClient, } @@ -146,7 +148,8 @@ func (monitor *CustomEndpointMonitorImpl) run() { monitor.pluginService.SetAllowedAndBlockedHosts(allowedAndBlockedHosts) - customEndpointInfoCache.Put(monitor.customEndpointHostInfo.GetHost(), endpointInfo, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANO) + customEndpointInfoCache.Put(monitor.getCustomEndpointInfoCacheKey(), endpointInfo, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANO) + monitor.infoChangedCounter.Inc(monitor.pluginService.GetTelemetryContext()) elapsedTime := time.Since(start) sleepDuration := monitor.refreshRateMs - elapsedTime diff --git a/custom-endpoint/custom_endpoint_plugin.go b/custom-endpoint/custom_endpoint_plugin.go index 49c7eaeb..73c5316c 100644 --- a/custom-endpoint/custom_endpoint_plugin.go +++ b/custom-endpoint/custom_endpoint_plugin.go @@ -44,6 +44,7 @@ func init() { } const TELEMETRY_WAIT_FOR_INFO_COUNTER = "customEndpoint.waitForInfo.counter" +const TELEMETRY_ENDPOINT_INFO_CHANGED = "customEndpoint.infoChanged.counter" type CustomEndpointPluginFactory struct{} @@ -148,7 +149,7 @@ func (plugin *CustomEndpointPlugin) GetSubscribedMethods() []string { func (plugin *CustomEndpointPlugin) Connect( hostInfo *host_info_util.HostInfo, props *utils.RWMap[string, string], - isInitialConnection bool, + _ bool, connectFunc driver_infrastructure.ConnectFunc) (driver.Conn, error) { if !utils.IsRdsCustomClusterDns(hostInfo.GetHost()) { return connectFunc(props) @@ -213,32 +214,38 @@ func (plugin *CustomEndpointPlugin) createMonitorIfAbsent( if err != nil { return nil, err } + infoChangedCounter, err := plugin.pluginService.GetTelemetryFactory().CreateCounter(TELEMETRY_ENDPOINT_INFO_CHANGED) + if err != nil { + return nil, err + } + return NewCustomEndpointMonitorImpl( plugin.pluginService, plugin.customEndpointHostInfo, plugin.customEndpointId, plugin.region, refreshRateMs, + infoChangedCounter, rdsClient, ), nil }, time.Duration(plugin.idleMonitorExpirationMs)*time.Millisecond) } func (plugin *CustomEndpointPlugin) waitForCustomEndpointInfo(monitor CustomEndpointMonitor) error { - hasCustomEdnpointInfo := monitor.HasCustomEndpointInfo() + hasCustomEndpointInfo := monitor.HasCustomEndpointInfo() - if !hasCustomEdnpointInfo { + if !hasCustomEndpointInfo { if plugin.waitForInfoCounter != nil { plugin.waitForInfoCounter.Inc(plugin.pluginService.GetTelemetryContext()) } waitForEndpointInfoTimeout := time.Now().Add(time.Millisecond * time.Duration(plugin.waitOnCachedInfoDurationMs)) - for !hasCustomEdnpointInfo && time.Now().Before(waitForEndpointInfoTimeout) { + for !hasCustomEndpointInfo && time.Now().Before(waitForEndpointInfoTimeout) { time.Sleep(time.Millisecond * time.Duration(100)) - hasCustomEdnpointInfo = monitor.HasCustomEndpointInfo() + hasCustomEndpointInfo = monitor.HasCustomEndpointInfo() } - if !hasCustomEdnpointInfo { + if !hasCustomEndpointInfo { return errors.New(error_util.GetMessage("CustomEndpointPlugin.timedOutWaitingForCustomEndpointInfo")) } }