Skip to content

Commit eeb2cd0

Browse files
committed
feat: attach trace context to LiteLLM headers
1 parent 69997cd commit eeb2cd0

File tree

2 files changed

+175
-0
lines changed

2 files changed

+175
-0
lines changed

src/google/adk/models/lite_llm.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from litellm import Message
5454
from litellm import ModelResponse
5555
from litellm import OpenAIMessageContent
56+
from opentelemetry import trace
5657
from pydantic import BaseModel
5758
from pydantic import Field
5859
from typing_extensions import override
@@ -274,6 +275,39 @@ class UsageMetadataChunk(BaseModel):
274275
class LiteLLMClient:
275276
"""Provides acompletion method (for better testability)."""
276277

278+
@staticmethod
279+
def _build_traceparent() -> Optional[str]:
280+
span_context = trace.get_current_span().get_span_context()
281+
if not span_context.is_valid:
282+
return None
283+
284+
trace_id = f"{span_context.trace_id:032x}"
285+
span_id = f"{span_context.span_id:016x}"
286+
trace_flags = f"{int(span_context.trace_flags):02x}"
287+
return f"00-{trace_id}-{span_id}-{trace_flags}"
288+
289+
@classmethod
290+
def _maybe_add_traceparent_header(
291+
cls, extra_headers: Optional[dict[str, str]]
292+
) -> Optional[dict[str, str]]:
293+
traceparent = cls._build_traceparent()
294+
if not traceparent:
295+
return extra_headers
296+
297+
headers_with_trace = dict(extra_headers) if extra_headers else {}
298+
headers_with_trace["traceparent"] = traceparent
299+
return headers_with_trace
300+
301+
@classmethod
302+
def _attach_traceparent_header(cls, kwargs: Dict[str, Any]) -> None:
303+
updated_headers = cls._maybe_add_traceparent_header(
304+
kwargs.get("extra_headers")
305+
)
306+
if updated_headers is None:
307+
kwargs.pop("extra_headers", None)
308+
else:
309+
kwargs["extra_headers"] = updated_headers
310+
277311
async def acompletion(
278312
self, model, messages, tools, **kwargs
279313
) -> Union[ModelResponse, CustomStreamWrapper]:
@@ -289,6 +323,8 @@ async def acompletion(
289323
The model response as a message.
290324
"""
291325

326+
self._attach_traceparent_header(kwargs)
327+
292328
return await acompletion(
293329
model=model,
294330
messages=messages,
@@ -312,6 +348,8 @@ def completion(
312348
The response from the model.
313349
"""
314350

351+
self._attach_traceparent_header(kwargs)
352+
315353
return completion(
316354
model=model,
317355
messages=messages,

tests/unittests/models/test_litellm.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from google.adk.models.lite_llm import LiteLLMClient
4646
from google.adk.models.lite_llm import TextChunk
4747
from google.adk.models.lite_llm import UsageMetadataChunk
48+
import google.adk.models.lite_llm as lite_llm_module
4849
from google.adk.models.llm_request import LlmRequest
4950
from google.genai import types
5051
import litellm
@@ -56,6 +57,9 @@
5657
from litellm.types.utils import Delta
5758
from litellm.types.utils import ModelResponse
5859
from litellm.types.utils import StreamingChoices
60+
from opentelemetry.trace import SpanContext
61+
from opentelemetry.trace import TraceFlags
62+
from opentelemetry.trace import TraceState
5963
from pydantic import BaseModel
6064
from pydantic import Field
6165
import pytest
@@ -218,6 +222,139 @@
218222
]
219223

220224

225+
class _StubSpan:
226+
227+
def __init__(self, span_context):
228+
self._span_context = span_context
229+
230+
def get_span_context(self):
231+
return self._span_context
232+
233+
234+
def _build_valid_span_context():
235+
return SpanContext(
236+
trace_id=int("0123456789abcdef0123456789abcdef", 16),
237+
span_id=int("abcdef0123456789", 16),
238+
is_remote=False,
239+
trace_flags=TraceFlags(1),
240+
trace_state=TraceState(),
241+
)
242+
243+
244+
def _build_invalid_span_context():
245+
return SpanContext(
246+
trace_id=0,
247+
span_id=0,
248+
is_remote=False,
249+
trace_flags=TraceFlags(0),
250+
trace_state=TraceState(),
251+
)
252+
253+
254+
def test_maybe_add_traceparent_header_with_existing_headers(monkeypatch):
255+
span_context = _build_valid_span_context()
256+
monkeypatch.setattr(
257+
lite_llm_module.trace,
258+
"get_current_span",
259+
lambda: _StubSpan(span_context),
260+
)
261+
262+
headers = {"custom": "header"}
263+
result = LiteLLMClient._maybe_add_traceparent_header(headers)
264+
265+
assert result is not headers
266+
assert result["custom"] == "header"
267+
assert result["traceparent"] == (
268+
"00-0123456789abcdef0123456789abcdef-abcdef0123456789-01"
269+
)
270+
271+
272+
def test_maybe_add_traceparent_header_without_existing_headers(monkeypatch):
273+
span_context = _build_valid_span_context()
274+
monkeypatch.setattr(
275+
lite_llm_module.trace,
276+
"get_current_span",
277+
lambda: _StubSpan(span_context),
278+
)
279+
280+
result = LiteLLMClient._maybe_add_traceparent_header(None)
281+
282+
assert result == {
283+
"traceparent": "00-0123456789abcdef0123456789abcdef-abcdef0123456789-01"
284+
}
285+
286+
287+
def test_maybe_add_traceparent_header_without_active_span(monkeypatch):
288+
span_context = _build_invalid_span_context()
289+
monkeypatch.setattr(
290+
lite_llm_module.trace,
291+
"get_current_span",
292+
lambda: _StubSpan(span_context),
293+
)
294+
295+
headers = {"custom": "value"}
296+
result = LiteLLMClient._maybe_add_traceparent_header(headers)
297+
298+
assert result is headers
299+
300+
301+
@pytest.mark.asyncio
302+
async def test_litellmclient_acompletion_sets_traceparent_header(monkeypatch):
303+
async_mock = AsyncMock(return_value="response")
304+
monkeypatch.setattr(lite_llm_module, "acompletion", async_mock)
305+
306+
def fake_helper(headers):
307+
assert headers == {"existing": "header"}
308+
return {"existing": "header", "traceparent": "tp"}
309+
310+
monkeypatch.setattr(
311+
LiteLLMClient, "_maybe_add_traceparent_header", fake_helper
312+
)
313+
314+
client = LiteLLMClient()
315+
await client.acompletion(
316+
model="test",
317+
messages=[],
318+
tools=None,
319+
extra_headers={"existing": "header"},
320+
custom="value",
321+
)
322+
323+
async_mock.assert_awaited_once()
324+
_, kwargs = async_mock.call_args
325+
assert kwargs["extra_headers"] == {
326+
"existing": "header",
327+
"traceparent": "tp",
328+
}
329+
assert kwargs["custom"] == "value"
330+
331+
332+
def test_litellmclient_completion_sets_traceparent_header(monkeypatch):
333+
sync_mock = Mock(return_value="response")
334+
monkeypatch.setattr(lite_llm_module, "completion", sync_mock)
335+
336+
def fake_helper(headers):
337+
assert headers is None
338+
return {"traceparent": "tp"}
339+
340+
monkeypatch.setattr(
341+
LiteLLMClient, "_maybe_add_traceparent_header", fake_helper
342+
)
343+
344+
client = LiteLLMClient()
345+
client.completion(
346+
model="test",
347+
messages=[],
348+
tools=None,
349+
stream=True,
350+
)
351+
352+
sync_mock.assert_called_once()
353+
_, kwargs = sync_mock.call_args
354+
assert kwargs["extra_headers"] == {"traceparent": "tp"}
355+
assert kwargs["stream"]
356+
357+
221358
class _StructuredOutput(BaseModel):
222359
value: int = Field(description="Value to emit")
223360

0 commit comments

Comments
 (0)