|
45 | 45 | from google.adk.models.lite_llm import LiteLLMClient |
46 | 46 | from google.adk.models.lite_llm import TextChunk |
47 | 47 | from google.adk.models.lite_llm import UsageMetadataChunk |
| 48 | +import google.adk.models.lite_llm as lite_llm_module |
48 | 49 | from google.adk.models.llm_request import LlmRequest |
49 | 50 | from google.genai import types |
50 | 51 | import litellm |
|
56 | 57 | from litellm.types.utils import Delta |
57 | 58 | from litellm.types.utils import ModelResponse |
58 | 59 | from litellm.types.utils import StreamingChoices |
| 60 | +from opentelemetry.trace import SpanContext |
| 61 | +from opentelemetry.trace import TraceFlags |
| 62 | +from opentelemetry.trace import TraceState |
59 | 63 | from pydantic import BaseModel |
60 | 64 | from pydantic import Field |
61 | 65 | import pytest |
|
218 | 222 | ] |
219 | 223 |
|
220 | 224 |
|
| 225 | +class _StubSpan: |
| 226 | + |
| 227 | + def __init__(self, span_context): |
| 228 | + self._span_context = span_context |
| 229 | + |
| 230 | + def get_span_context(self): |
| 231 | + return self._span_context |
| 232 | + |
| 233 | + |
| 234 | +def _build_valid_span_context(): |
| 235 | + return SpanContext( |
| 236 | + trace_id=int("0123456789abcdef0123456789abcdef", 16), |
| 237 | + span_id=int("abcdef0123456789", 16), |
| 238 | + is_remote=False, |
| 239 | + trace_flags=TraceFlags(1), |
| 240 | + trace_state=TraceState(), |
| 241 | + ) |
| 242 | + |
| 243 | + |
| 244 | +def _build_invalid_span_context(): |
| 245 | + return SpanContext( |
| 246 | + trace_id=0, |
| 247 | + span_id=0, |
| 248 | + is_remote=False, |
| 249 | + trace_flags=TraceFlags(0), |
| 250 | + trace_state=TraceState(), |
| 251 | + ) |
| 252 | + |
| 253 | + |
| 254 | +def test_maybe_add_traceparent_header_with_existing_headers(monkeypatch): |
| 255 | + span_context = _build_valid_span_context() |
| 256 | + monkeypatch.setattr( |
| 257 | + lite_llm_module.trace, |
| 258 | + "get_current_span", |
| 259 | + lambda: _StubSpan(span_context), |
| 260 | + ) |
| 261 | + |
| 262 | + headers = {"custom": "header"} |
| 263 | + result = LiteLLMClient._maybe_add_traceparent_header(headers) |
| 264 | + |
| 265 | + assert result is not headers |
| 266 | + assert result["custom"] == "header" |
| 267 | + assert result["traceparent"] == ( |
| 268 | + "00-0123456789abcdef0123456789abcdef-abcdef0123456789-01" |
| 269 | + ) |
| 270 | + |
| 271 | + |
| 272 | +def test_maybe_add_traceparent_header_without_existing_headers(monkeypatch): |
| 273 | + span_context = _build_valid_span_context() |
| 274 | + monkeypatch.setattr( |
| 275 | + lite_llm_module.trace, |
| 276 | + "get_current_span", |
| 277 | + lambda: _StubSpan(span_context), |
| 278 | + ) |
| 279 | + |
| 280 | + result = LiteLLMClient._maybe_add_traceparent_header(None) |
| 281 | + |
| 282 | + assert result == { |
| 283 | + "traceparent": "00-0123456789abcdef0123456789abcdef-abcdef0123456789-01" |
| 284 | + } |
| 285 | + |
| 286 | + |
| 287 | +def test_maybe_add_traceparent_header_without_active_span(monkeypatch): |
| 288 | + span_context = _build_invalid_span_context() |
| 289 | + monkeypatch.setattr( |
| 290 | + lite_llm_module.trace, |
| 291 | + "get_current_span", |
| 292 | + lambda: _StubSpan(span_context), |
| 293 | + ) |
| 294 | + |
| 295 | + headers = {"custom": "value"} |
| 296 | + result = LiteLLMClient._maybe_add_traceparent_header(headers) |
| 297 | + |
| 298 | + assert result is headers |
| 299 | + |
| 300 | + |
| 301 | +@pytest.mark.asyncio |
| 302 | +async def test_litellmclient_acompletion_sets_traceparent_header(monkeypatch): |
| 303 | + async_mock = AsyncMock(return_value="response") |
| 304 | + monkeypatch.setattr(lite_llm_module, "acompletion", async_mock) |
| 305 | + |
| 306 | + def fake_helper(headers): |
| 307 | + assert headers == {"existing": "header"} |
| 308 | + return {"existing": "header", "traceparent": "tp"} |
| 309 | + |
| 310 | + monkeypatch.setattr( |
| 311 | + LiteLLMClient, "_maybe_add_traceparent_header", fake_helper |
| 312 | + ) |
| 313 | + |
| 314 | + client = LiteLLMClient() |
| 315 | + await client.acompletion( |
| 316 | + model="test", |
| 317 | + messages=[], |
| 318 | + tools=None, |
| 319 | + extra_headers={"existing": "header"}, |
| 320 | + custom="value", |
| 321 | + ) |
| 322 | + |
| 323 | + async_mock.assert_awaited_once() |
| 324 | + _, kwargs = async_mock.call_args |
| 325 | + assert kwargs["extra_headers"] == { |
| 326 | + "existing": "header", |
| 327 | + "traceparent": "tp", |
| 328 | + } |
| 329 | + assert kwargs["custom"] == "value" |
| 330 | + |
| 331 | + |
| 332 | +def test_litellmclient_completion_sets_traceparent_header(monkeypatch): |
| 333 | + sync_mock = Mock(return_value="response") |
| 334 | + monkeypatch.setattr(lite_llm_module, "completion", sync_mock) |
| 335 | + |
| 336 | + def fake_helper(headers): |
| 337 | + assert headers is None |
| 338 | + return {"traceparent": "tp"} |
| 339 | + |
| 340 | + monkeypatch.setattr( |
| 341 | + LiteLLMClient, "_maybe_add_traceparent_header", fake_helper |
| 342 | + ) |
| 343 | + |
| 344 | + client = LiteLLMClient() |
| 345 | + client.completion( |
| 346 | + model="test", |
| 347 | + messages=[], |
| 348 | + tools=None, |
| 349 | + stream=True, |
| 350 | + ) |
| 351 | + |
| 352 | + sync_mock.assert_called_once() |
| 353 | + _, kwargs = sync_mock.call_args |
| 354 | + assert kwargs["extra_headers"] == {"traceparent": "tp"} |
| 355 | + assert kwargs["stream"] |
| 356 | + |
| 357 | + |
221 | 358 | class _StructuredOutput(BaseModel): |
222 | 359 | value: int = Field(description="Value to emit") |
223 | 360 |
|
|
0 commit comments