11import unittest
2+ import base64
23import json
34from unittest.mock import patch
45
56from datadog_lambda.dsm import (
67 set_dsm_context,
78 _dsm_set_sqs_context,
9+ _dsm_set_sns_context,
810 _get_dsm_context_from_lambda,
911)
1012from datadog_lambda.trigger import EventTypes, _EventSource
@@ -16,14 +18,18 @@ def setUp(self):
1618 self.mock_dsm_set_sqs_context = patcher.start()
1719 self.addCleanup(patcher.stop)
1820
19- # Patch set_consume_checkpoint for testing DSM functionality
2021 patcher = patch("ddtrace.data_streams.set_consume_checkpoint")
2122 self.mock_set_consume_checkpoint = patcher.start()
2223 self.addCleanup(patcher.stop)
2324
24- # Patch _get_dsm_context_from_lambda for testing DSM context extraction
2525 patcher = patch("datadog_lambda.dsm._get_dsm_context_from_lambda")
2626 self.mock_get_dsm_context_from_lambda = patcher.start()
27+ patcher = patch("datadog_lambda.dsm._dsm_set_sns_context")
28+ self.mock_dsm_set_sns_context = patcher.start()
29+ self.addCleanup(patcher.stop)
30+
31+ patcher = patch("ddtrace.internal.datastreams.data_streams_processor")
32+ self.mock_data_streams_processor = patcher.start()
2733 self.addCleanup(patcher.stop)
2834
2935 def test_non_sqs_event_source_does_nothing(self):
@@ -140,6 +146,123 @@ def test_sqs_multiple_records_process_each_record(self):
140146 pathway_ctx = carrier_get_func("dd-pathway-ctx-base64")
141147 self.assertEqual(pathway_ctx, expected_contexts[i])
142148
149+ def test_sns_event_with_no_records_does_nothing(self):
150+ """Test that events where Records is None don't trigger DSM processing"""
151+ events_with_no_records = [
152+ {},
153+ {"Records": None},
154+ {"someOtherField": "value"},
155+ ]
156+
157+ for event in events_with_no_records:
158+ _dsm_set_sns_context(event)
159+ self.mock_set_consume_checkpoint.assert_not_called()
160+
161+ def test_sns_event_triggers_dsm_sns_context(self):
162+ """Test that SNS event sources trigger the SNS-specific DSM context function"""
163+ sns_event = {
164+ "Records": [
165+ {
166+ "EventSource": "aws:sns",
167+ "Sns": {
168+ "TopicArn": "arn:aws:sns:us-east-1:123456789012:my-topic",
169+ "Message": "Hello from SNS!",
170+ },
171+ }
172+ ]
173+ }
174+
175+ event_source = _EventSource(EventTypes.SNS)
176+ set_dsm_context(sns_event, event_source)
177+
178+ self.mock_dsm_set_sns_context.assert_called_once_with(sns_event)
179+
180+ def test_sns_multiple_records_process_each_record(self):
181+ """Test that each record in an SNS event gets processed individually"""
182+ multi_record_event = {
183+ "Records": [
184+ {
185+ "EventSource": "aws:sns",
186+ "Sns": {
187+ "TopicArn": "arn:aws:sns:us-east-1:123456789012:topic1",
188+ "Message": "Message 1",
189+ "MessageAttributes": {
190+ "_datadog": {
191+ "Type": "Binary",
192+ "Value": base64.b64encode(
193+ json.dumps({"dd-pathway-ctx-base64": "context1"})
194+ .encode("utf-8")
195+ ).decode("utf-8")
196+ }
197+ },
198+ }
199+ },
200+ {
201+ "EventSource": "aws:sns",
202+ "Sns": {
203+ "TopicArn": "arn:aws:sns:us-east-1:123456789012:topic2",
204+ "Message": "Message 2",
205+ "MessageAttributes": {
206+ "_datadog": {
207+ "Type": "Binary",
208+ "Value": base64.b64encode(
209+ json.dumps({"dd-pathway-ctx-base64": "context2"})
210+ .encode("utf-8")
211+ ).decode("utf-8")
212+ }
213+ },
214+ }
215+ },
216+ {
217+ "EventSource": "aws:sns",
218+ "Sns": {
219+ "TopicArn": "arn:aws:sns:us-east-1:123456789012:topic3",
220+ "Message": "Message 3",
221+ "MessageAttributes": {
222+ "_datadog": {
223+ "Type": "Binary",
224+ "Value": base64.b64encode(
225+ json.dumps({"dd-pathway-ctx-base64": "context3"})
226+ .encode("utf-8")
227+ ).decode("utf-8")
228+ }
229+ },
230+ }
231+ },
232+ ]
233+ }
234+
235+ self.mock_get_dsm_context_from_lambda.side_effect = [
236+ {"dd-pathway-ctx-base64": "context1"},
237+ {"dd-pathway-ctx-base64": "context2"},
238+ {"dd-pathway-ctx-base64": "context3"},
239+ ]
240+
241+ _dsm_set_sns_context(multi_record_event)
242+
243+ self.assertEqual(self.mock_set_consume_checkpoint.call_count, 3)
244+
245+ calls = self.mock_set_consume_checkpoint.call_args_list
246+ expected_arns = [
247+ "arn:aws:sns:us-east-1:123456789012:topic1",
248+ "arn:aws:sns:us-east-1:123456789012:topic2",
249+ "arn:aws:sns:us-east-1:123456789012:topic3",
250+ ]
251+ expected_contexts = ["context1", "context2", "context3"]
252+
253+ for i, call in enumerate(calls):
254+ args, kwargs = call
255+ service_type = args[0]
256+ arn = args[1]
257+ carrier_get_func = args[2]
258+
259+ self.assertEqual(service_type, "sns")
260+
261+ self.assertEqual(arn, expected_arns[i])
262+
263+ pathway_ctx = carrier_get_func("dd-pathway-ctx-base64")
264+ self.assertEqual(pathway_ctx, expected_contexts[i])
265+
143266
144267class TestGetDSMContext(unittest.TestCase):
145268 def test_sqs_to_lambda_string_value_format(self):
@@ -188,6 +311,43 @@ def test_sqs_to_lambda_string_value_format(self):
188311 assert result["x-datadog-parent-id"] == "321987654"
189312 assert result["dd-pathway-ctx"] == "test-pathway-ctx"
190313
314+ def test_sns_to_lambda_format(self):
315+ """Test format: message.Sns.MessageAttributes._datadog.Value.decode() (SNS -> lambda)"""
316+ trace_context = {
317+ "x-datadog-trace-id": "111111111",
318+ "x-datadog-parent-id": "222222222",
319+ "dd-pathway-ctx": "test-pathway-ctx",
320+ }
321+ binary_data = base64.b64encode(
322+ json.dumps(trace_context).encode("utf-8")
323+ ).decode("utf-8")
324+
325+ sns_lambda_record = {
326+ "EventSource": "aws:sns",
327+ "EventSubscriptionArn": (
328+ "arn:aws:sns:us-east-1:123456789012:sns-topic:12345678-1234-1234-1234-123456789012"
329+ ),
330+ "Sns": {
331+ "Type": "Notification",
332+ "MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
333+ "TopicArn": "arn:aws:sns:us-east-1:123456789012:sns-topic",
334+ "Subject": "Test Subject",
335+ "Message": "Hello from SNS!",
336+ "Timestamp": "2023-01-01T12:00:00.000Z",
337+ "MessageAttributes": {
338+ "_datadog": {"Type": "Binary", "Value": binary_data}
339+ },
340+ },
341+ }
342+
343+ result = _get_dsm_context_from_lambda(sns_lambda_record)
344+
345+ assert result is not None
346+ assert result == trace_context
347+ assert result["x-datadog-trace-id"] == "111111111"
348+ assert result["x-datadog-parent-id"] == "222222222"
349+ assert result["dd-pathway-ctx"] == "test-pathway-ctx"
350+
191351 def test_no_message_attributes(self):
192352 """Test message without MessageAttributes returns None."""
193353 message = {
0 commit comments