Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,46 @@
"before a response was recorded)."
)

_LITELLM_THOUGHT_SIGNATURE_SEPARATOR = "__thought__"


def _decode_litellm_tool_call_id(
tool_call_id: str,
) -> tuple[str, Optional[bytes]]:
"""Extracts thought_signature bytes from a LiteLLM tool call id."""
if not tool_call_id:
return tool_call_id, None

base_id, separator, encoded_signature = tool_call_id.partition(
_LITELLM_THOUGHT_SIGNATURE_SEPARATOR
)
if not separator or not encoded_signature:
return base_id, None

try:
return base_id, base64.b64decode(encoded_signature)
except (ValueError, TypeError) as err:
logger.warning(
"Failed to decode thought_signature from tool call id %r: %s",
tool_call_id,
err,
)
return base_id, None


def _encode_litellm_tool_call_id(
tool_call_id: Optional[str], thought_signature: Optional[bytes]
) -> Optional[str]:
"""Embeds thought_signature bytes in a LiteLLM-compatible tool call id."""
if not tool_call_id or not thought_signature:
return tool_call_id

encoded_signature = base64.b64encode(thought_signature).decode("utf-8")
return (
f"{tool_call_id}{_LITELLM_THOUGHT_SIGNATURE_SEPARATOR}"
f"{encoded_signature}"
)

_LITELLM_IMPORTED = False
_LITELLM_GLOBAL_SYMBOLS = (
"ChatCompletionAssistantMessage",
Expand Down Expand Up @@ -665,7 +705,10 @@ async def _content_to_message_param(
tool_calls.append(
ChatCompletionAssistantToolCall(
type="function",
id=part.function_call.id,
id=_encode_litellm_tool_call_id(
part.function_call.id,
part.thought_signature,
),
function=Function(
name=part.function_call.name,
arguments=_safe_json_serialize(part.function_call.args),
Expand Down Expand Up @@ -1481,7 +1524,12 @@ def _message_to_generate_content_response(
name=tool_call.function.name,
args=json.loads(tool_call.function.arguments or "{}"),
)
part.function_call.id = tool_call.id
tool_call_id, thought_signature = _decode_litellm_tool_call_id(
tool_call.id
)
part.function_call.id = tool_call_id
if thought_signature:
part.thought_signature = thought_signature
parts.append(part)

return LlmResponse(
Expand Down
51 changes: 51 additions & 0 deletions tests/unittests/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License

import base64
import contextlib
import json
import logging
Expand Down Expand Up @@ -2217,6 +2218,56 @@ def test_message_to_generate_content_response_tool_call():
assert response.content.parts[0].function_call.id == "test_tool_call_id"


def test_message_to_generate_content_response_tool_call_with_thought_signature():
  """A tool call id carrying a __thought__ suffix yields the decoded signature."""
  raw_signature = b"gemini_signature"
  suffix = base64.b64encode(raw_signature).decode("utf-8")
  tool_call = ChatCompletionMessageToolCall(
      type="function",
      id=f"test_tool_call_id__thought__{suffix}",
      function=Function(
          name="test_function",
          arguments='{"test_arg": "test_value"}',
      ),
  )
  message = ChatCompletionAssistantMessage(
      role="assistant",
      content=None,
      tool_calls=[tool_call],
  )

  response = _message_to_generate_content_response(message)

  part = response.content.parts[0]
  assert response.content.role == "model"
  assert part.function_call.name == "test_function"
  assert part.function_call.args == {"test_arg": "test_value"}
  # The base id is restored and the signature surfaces on the part.
  assert part.function_call.id == "test_tool_call_id"
  assert part.thought_signature == raw_signature


@pytest.mark.asyncio
async def test_content_to_message_param_embeds_thought_signature_in_tool_call():
  """The outgoing tool call id embeds the part's thought_signature."""
  signature = b"gemini_signature"
  function_part = types.Part.from_function_call(
      name="test_function",
      args={"test_arg": "test_value"},
  )
  function_part.function_call.id = "test_tool_call_id"
  function_part.thought_signature = signature
  content = types.Content(role="model", parts=[function_part])

  message = await _content_to_message_param(content)

  tool_calls = message["tool_calls"]
  assert tool_calls is not None
  assert len(tool_calls) == 1
  encoded = base64.b64encode(signature).decode("utf-8")
  assert tool_calls[0]["id"] == f"test_tool_call_id__thought__{encoded}"


def test_message_to_generate_content_response_inline_tool_call_text():
message = ChatCompletionAssistantMessage(
role="assistant",
Expand Down
Loading