Skip to content

Commit 59aba68

Browse files
committed
Add FunctionInvocationId header propagation for client operations
Propagates the Azure Functions invocation ID to the Durable Functions host via the X-Azure-Functions-InvocationId HTTP header, enabling correlation between worker-side function invocations and host-side orchestration events. - Modified http_utils.py to accept optional function_invocation_id parameter - Updated DurableOrchestrationClient to pass invocation ID to HTTP calls - Added optional function_invocation_id parameter to DurableApp.client decorator - Added unit tests for header propagation Related to Azure/azure-functions-durable-extension#3317
1 parent 4ab004b commit 59aba68

File tree

5 files changed

+143
-15
lines changed

5 files changed

+143
-15
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## Unreleased
6+
7+
### Added
8+
9+
- Client operation correlation logging: `FunctionInvocationId` is now propagated via HTTP headers to the host for client operations, enabling correlation with host logs.
10+
511
## 1.0.0b6
612

713
- [Create timer](https://github.com/Azure/azure-functions-durable-python/issues/35) functionality available

azure/durable_functions/decorators/durable_app.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,14 @@ async def df_client_middleware(*args, **kwargs):
195195
# construct rich object from it,
196196
# and assign parameter to that rich object
197197
starter = kwargs[parameter_name]
198-
client = client_constructor(starter)
198+
199+
# Try to extract the function invocation ID from the context for correlation
200+
function_invocation_id = None
201+
context = kwargs.get('context')
202+
if context is not None and hasattr(context, 'invocation_id'):
203+
function_invocation_id = context.invocation_id
204+
205+
client = client_constructor(starter, function_invocation_id)
199206
kwargs[parameter_name] = client
200207

201208
# Invoke user code with rich DF Client binding

azure/durable_functions/models/DurableOrchestrationClient.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,16 @@ class DurableOrchestrationClient:
2626
orchestration instances.
2727
"""
2828

29-
def __init__(self, context: str):
29+
def __init__(self, context: str, function_invocation_id: Optional[str] = None):
30+
"""Initialize a DurableOrchestrationClient.
31+
32+
Parameters
33+
----------
34+
context : str
35+
The JSON-encoded client binding context.
36+
function_invocation_id : Optional[str]
37+
The function invocation ID for correlation with host-side logs.
38+
"""
3039
self.task_hub_name: str
3140
self._uniqueWebHookOrigins: List[str]
3241
self._event_name_placeholder: str = "{eventName}"
@@ -39,6 +48,7 @@ def __init__(self, context: str):
3948
self._show_history_query_key: str = "showHistory"
4049
self._show_history_output_query_key: str = "showHistoryOutput"
4150
self._show_input_query_key: str = "showInput"
51+
self._function_invocation_id: Optional[str] = function_invocation_id
4252
self._orchestration_bindings: DurableOrchestrationBindings = \
4353
DurableOrchestrationBindings.from_json(context)
4454
self._post_async_request = post_async_request
@@ -84,7 +94,8 @@ async def start_new(self,
8494
request_url,
8595
self._get_json_input(client_input),
8696
trace_parent,
87-
trace_state)
97+
trace_state,
98+
self._function_invocation_id)
8899

89100
status_code: int = response[0]
90101
if status_code <= 202 and response[1]:
@@ -100,6 +111,7 @@ async def start_new(self,
100111
ex_message: Any = response[1]
101112
raise Exception(ex_message)
102113

114+
103115
def create_check_status_response(
104116
self, request: func.HttpRequest, instance_id: str) -> func.HttpResponse:
105117
"""Create a HttpResponse that contains useful information for \
@@ -256,7 +268,10 @@ async def raise_event(
256268
request_url = self._get_raise_event_url(
257269
instance_id, event_name, task_hub_name, connection_name)
258270

259-
response = await self._post_async_request(request_url, json.dumps(event_data))
271+
response = await self._post_async_request(
272+
request_url,
273+
json.dumps(event_data),
274+
function_invocation_id=self._function_invocation_id)
260275

261276
switch_statement = {
262277
202: lambda: None,
@@ -445,7 +460,10 @@ async def terminate(self, instance_id: str, reason: str) -> None:
445460
"""
446461
request_url = f"{self._orchestration_bindings.rpc_base_url}instances/{instance_id}/" \
447462
f"terminate?reason={quote(reason)}"
448-
response = await self._post_async_request(request_url, None)
463+
response = await self._post_async_request(
464+
request_url,
465+
None,
466+
function_invocation_id=self._function_invocation_id)
449467
switch_statement = {
450468
202: lambda: None, # instance in progress
451469
410: lambda: None, # instance failed or terminated
@@ -564,7 +582,8 @@ async def signal_entity(self, entityId: EntityId, operation_name: str,
564582
request_url,
565583
json.dumps(operation_input) if operation_input else None,
566584
trace_parent,
567-
trace_state)
585+
trace_state,
586+
self._function_invocation_id)
568587

569588
switch_statement = {
570589
202: lambda: None # signal accepted
@@ -714,7 +733,10 @@ async def rewind(self,
714733
raise Exception("The Python SDK only supports RPC endpoints."
715734
+ "Please remove the `localRpcEnabled` setting from host.json")
716735

717-
response = await self._post_async_request(request_url, None)
736+
response = await self._post_async_request(
737+
request_url,
738+
None,
739+
function_invocation_id=self._function_invocation_id)
718740
status: int = response[0]
719741
ex_msg: str = ""
720742
if status == 200 or status == 202:
@@ -753,7 +775,10 @@ async def suspend(self, instance_id: str, reason: str) -> None:
753775
"""
754776
request_url = f"{self._orchestration_bindings.rpc_base_url}instances/{instance_id}/" \
755777
f"suspend?reason={quote(reason)}"
756-
response = await self._post_async_request(request_url, None)
778+
response = await self._post_async_request(
779+
request_url,
780+
None,
781+
function_invocation_id=self._function_invocation_id)
757782
switch_statement = {
758783
202: lambda: None, # instance is suspended
759784
410: lambda: None, # instance completed
@@ -788,7 +813,10 @@ async def resume(self, instance_id: str, reason: str) -> None:
788813
"""
789814
request_url = f"{self._orchestration_bindings.rpc_base_url}instances/{instance_id}/" \
790815
f"resume?reason={quote(reason)}"
791-
response = await self._post_async_request(request_url, None)
816+
response = await self._post_async_request(
817+
request_url,
818+
None,
819+
function_invocation_id=self._function_invocation_id)
792820
switch_statement = {
793821
202: lambda: None, # instance is resumed
794822
410: lambda: None, # instance completed

azure/durable_functions/models/utils/http_utils.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from typing import Any, List, Union
1+
from typing import Any, List, Union, Optional
22

33
import aiohttp
44

55

66
async def post_async_request(url: str,
77
data: Any = None,
88
trace_parent: str = None,
9-
trace_state: str = None) -> List[Union[int, Any]]:
9+
trace_state: str = None,
10+
function_invocation_id: str = None) -> List[Union[int, Any]]:
1011
"""Post request with the data provided to the url provided.
1112
1213
Parameters
@@ -19,6 +20,8 @@ async def post_async_request(url: str,
1920
traceparent header to send with the request
2021
trace_state: str
2122
tracestate header to send with the request
23+
function_invocation_id: str
24+
function invocation ID header to send for correlation
2225
2326
Returns
2427
-------
@@ -31,6 +34,8 @@ async def post_async_request(url: str,
3134
headers["traceparent"] = trace_parent
3235
if trace_state:
3336
headers["tracestate"] = trace_state
37+
if function_invocation_id:
38+
headers["X-Azure-Functions-InvocationId"] = function_invocation_id
3439
async with session.post(url, json=data, headers=headers) as response:
3540
# We disable aiohttp's input type validation
3641
# as the server may respond with alternative
@@ -40,41 +45,53 @@ async def post_async_request(url: str,
4045
return [response.status, data]
4146

4247

43-
async def get_async_request(url: str) -> List[Any]:
48+
async def get_async_request(url: str,
49+
function_invocation_id: str = None) -> List[Any]:
4450
"""Get the data from the url provided.
4551
4652
Parameters
4753
----------
4854
url: str
4955
url to get the data from
56+
function_invocation_id: str
57+
function invocation ID header to send for correlation
5058
5159
Returns
5260
-------
5361
[int, Any]
5462
Tuple with the Response status code and the data returned from the request
5563
"""
5664
async with aiohttp.ClientSession() as session:
57-
async with session.get(url) as response:
65+
headers = {}
66+
if function_invocation_id:
67+
headers["X-Azure-Functions-InvocationId"] = function_invocation_id
68+
async with session.get(url, headers=headers) as response:
5869
data = await response.json(content_type=None)
5970
if data is None:
6071
data = ""
6172
return [response.status, data]
6273

6374

64-
async def delete_async_request(url: str) -> List[Union[int, Any]]:
75+
async def delete_async_request(url: str,
76+
function_invocation_id: str = None) -> List[Union[int, Any]]:
6577
"""Delete the data from the url provided.
6678
6779
Parameters
6880
----------
6981
url: str
7082
url to delete the data from
83+
function_invocation_id: str
84+
function invocation ID header to send for correlation
7185
7286
Returns
7387
-------
7488
[int, Any]
7589
Tuple with the Response status code and the data returned from the request
7690
"""
7791
async with aiohttp.ClientSession() as session:
78-
async with session.delete(url) as response:
92+
headers = {}
93+
if function_invocation_id:
94+
headers["X-Azure-Functions-InvocationId"] = function_invocation_id
95+
async with session.delete(url, headers=headers) as response:
7996
data = await response.json(content_type=None)
8097
return [response.status, data]

tests/models/test_DurableOrchestrationClient.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,3 +739,73 @@ async def test_post_500_resume(binding_string):
739739

740740
with pytest.raises(Exception):
741741
await client.resume(TEST_INSTANCE_ID, raw_reason)
742+
743+
744+
# Tests for function_invocation_id parameter
745+
def test_client_stores_function_invocation_id(binding_string):
746+
"""Test that the client stores the function_invocation_id parameter."""
747+
invocation_id = "test-invocation-123"
748+
client = DurableOrchestrationClient(binding_string, function_invocation_id=invocation_id)
749+
assert client._function_invocation_id == invocation_id
750+
751+
752+
def test_client_stores_none_when_no_invocation_id(binding_string):
753+
"""Test that the client stores None when no invocation ID is provided."""
754+
client = DurableOrchestrationClient(binding_string)
755+
assert client._function_invocation_id is None
756+
757+
758+
class MockRequestWithInvocationId:
759+
"""Mock request class that verifies function_invocation_id is passed."""
760+
761+
def __init__(self, expected_url: str, response: [int, any], expected_invocation_id: str = None):
762+
self._expected_url = expected_url
763+
self._response = response
764+
self._expected_invocation_id = expected_invocation_id
765+
self._received_invocation_id = None
766+
767+
@property
768+
def received_invocation_id(self):
769+
return self._received_invocation_id
770+
771+
async def post(self, url: str, data: Any = None, trace_parent: str = None,
772+
trace_state: str = None, function_invocation_id: str = None):
773+
assert url == self._expected_url
774+
self._received_invocation_id = function_invocation_id
775+
if self._expected_invocation_id is not None:
776+
assert function_invocation_id == self._expected_invocation_id
777+
return self._response
778+
779+
780+
@pytest.mark.asyncio
781+
async def test_start_new_passes_invocation_id(binding_string):
782+
"""Test that start_new passes the function_invocation_id to the HTTP request."""
783+
invocation_id = "test-invocation-456"
784+
function_name = "MyOrchestrator"
785+
786+
mock_request = MockRequestWithInvocationId(
787+
expected_url=f"{RPC_BASE_URL}orchestrators/{function_name}",
788+
response=[202, {"id": TEST_INSTANCE_ID}],
789+
expected_invocation_id=invocation_id)
790+
791+
client = DurableOrchestrationClient(binding_string, function_invocation_id=invocation_id)
792+
client._post_async_request = mock_request.post
793+
794+
await client.start_new(function_name)
795+
assert mock_request.received_invocation_id == invocation_id
796+
797+
798+
@pytest.mark.asyncio
799+
async def test_start_new_passes_none_when_no_invocation_id(binding_string):
800+
"""Test that start_new passes None when no invocation ID is provided."""
801+
function_name = "MyOrchestrator"
802+
803+
mock_request = MockRequestWithInvocationId(
804+
expected_url=f"{RPC_BASE_URL}orchestrators/{function_name}",
805+
response=[202, {"id": TEST_INSTANCE_ID}])
806+
807+
client = DurableOrchestrationClient(binding_string)
808+
client._post_async_request = mock_request.post
809+
810+
await client.start_new(function_name)
811+
assert mock_request.received_invocation_id is None

0 commit comments

Comments
 (0)