diff --git a/aws/logs_monitoring/steps/enrichment.py b/aws/logs_monitoring/steps/enrichment.py index f119fd01..b25be32f 100644 --- a/aws/logs_monitoring/steps/enrichment.py +++ b/aws/logs_monitoring/steps/enrichment.py @@ -161,12 +161,11 @@ def extract_ddtags_from_message(event): extracted_ddtags = extracted_ddtags.replace(" ", "") # Extract service tag from message.ddtags if exists - if "service:" in extracted_ddtags: - event[DD_SERVICE] = next( - tag[8:] - for tag in extracted_ddtags.split(",") - if tag.startswith("service:") - ) + service_tags = [ + tag for tag in extracted_ddtags.split(",") if tag.startswith("service:") + ] + if service_tags: + event[DD_SERVICE] = service_tags[0][8:] event[DD_CUSTOM_TAGS] = ",".join( [ tag diff --git a/aws/logs_monitoring/tests/test_enrichment.py b/aws/logs_monitoring/tests/test_enrichment.py index 88b1b3a6..d8b65be5 100644 --- a/aws/logs_monitoring/tests/test_enrichment.py +++ b/aws/logs_monitoring/tests/test_enrichment.py @@ -146,6 +146,31 @@ def test_extract_ddtags_handles_empty_spaces(self): "my_custom_service", ) + def test_extract_ddtags_service_substring_not_prefix(self): + """Test that service tag extraction handles cases where 'service:' appears + as a substring but not as a tag prefix (e.g., in 'myservice:api'). + + Before the fix (old code): + - "service:" in extracted_ddtags would be True + - But next() would raise StopIteration since no tag starts with "service:" + + After the fix (new code): + - List comprehension returns empty list when no tag starts with "service:" + - No exception is raised, service field is not set + """ + loaded_message_tags = {"ddtags": "env:prod,myservice:api,foo:bar"} + event = {"message": loaded_message_tags, "ddtags": "custom_tag:value"} + + extract_ddtags_from_message(event) + + # Service should not be set since no tag actually starts with "service:" + self.assertNotIn("service", event) + # All tags should be preserved in the correct order + self.assertEqual( + event["ddtags"], + "custom_tag:value,env:prod,myservice:api,foo:bar", + ) + class TestExtractHostFromLogEvents(unittest.TestCase): def test_parse_source_cloudtrail(self):