-
Notifications
You must be signed in to change notification settings - Fork 1.5k
[WIP] Allow customization of prompts(ModelRetry,ToolReturn), tool_args, tool_descriptions sent to the Model. #3656
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 38 commits
543d414
c5686de
8d9d9b9
4968c1f
5d04126
b901be7
ea4a9b8
933022f
6cc9b1d
c8ebcea
7eaa90b
16c4d92
edb115a
c1d77cf
e8de0b3
086e035
f23b841
1ef50f3
acc5420
a34c391
5920092
4ac181f
ef8cc54
874d70e
c4ef9ba
71608af
e368175
e7fc0c9
eefe430
d2d0498
d4a0c2d
2e8a1f2
f8b5026
b3632b7
9bebf4f
59981c1
987293e
400b34e
3c6ea8e
def1747
74c6e23
9aadb71
41f4f2b
8253d8f
09b2597
0477465
f5fb994
b6415aa
339ea74
8141c3a
fee446d
45dff51
0f729f0
3570d40
946a20b
454bda1
da87aa5
539be42
65e1321
1ef3ddc
05d031e
6723457
e28a4ff
a31598b
04b6f14
c979029
201d7f6
cdc477d
81755bb
7df0a25
165e795
f39cc66
90bfab2
c91b2ab
12c3ad6
4069ad8
378d0e6
2c1fe89
2c74c5a
c36ee12
2bde52a
18b2fa8
ca0c29c
e5285f0
ecf13ce
60c1f89
b0aa837
aeef8ba
d9a8ff9
5b66710
bdd2ea1
65c5f14
6aa80f9
38f7d0d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -138,6 +138,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): | |
|
|
||
| model: models.Model | ||
| model_settings: ModelSettings | None | ||
| prompt_templates: _messages.PromptTemplates | None | ||
| usage_limits: _usage.UsageLimits | ||
| max_result_retries: int | ||
| end_strategy: EndStrategy | ||
|
|
@@ -509,6 +510,10 @@ async def _prepare_request( | |
| # Update the new message index to ensure `result.new_messages()` returns the correct messages | ||
| ctx.deps.new_message_index -= len(original_history) - len(message_history) | ||
|
|
||
| prompt_templates = ctx.deps.prompt_templates | ||
| if prompt_templates: | ||
| _apply_prompt_templates(message_history, prompt_templates, run_context) | ||
|
|
||
| # Merge possible consecutive trailing `ModelRequest`s into one, with tool call parts before user parts, | ||
| # but don't store it in the message history on state. This is just for the benefit of model classes that want clear user/assistant boundaries. | ||
| # See `tests/test_tools.py::test_parallel_tool_return_with_deferred` for an example where this is necessary | ||
|
|
@@ -783,6 +788,11 @@ def _handle_final_result( | |
| ) -> End[result.FinalResult[NodeRunEndT]]: | ||
| messages = ctx.state.message_history | ||
|
|
||
| if tool_responses and ctx.deps.prompt_templates: | ||
| run_ctx = build_run_context(ctx) | ||
| for part in tool_responses: | ||
| ctx.deps.prompt_templates.apply_template(part, run_ctx) | ||
|
|
||
| # For backwards compatibility, append a new ModelRequest using the tool returns and retries | ||
| if tool_responses: | ||
| messages.append(_messages.ModelRequest(parts=tool_responses, run_id=ctx.state.run_id)) | ||
|
|
@@ -871,13 +881,15 @@ async def process_tool_calls( # noqa: C901 | |
| tool_name=call.tool_name, | ||
| content='Final result processed.', | ||
adtyavrdhn marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| tool_call_id=call.tool_call_id, | ||
| return_kind='final-result-processed', | ||
| ) | ||
| else: | ||
| yield _messages.FunctionToolCallEvent(call) | ||
| part = _messages.ToolReturnPart( | ||
| tool_name=call.tool_name, | ||
| content='Output tool not used - a final result was already processed.', | ||
| tool_call_id=call.tool_call_id, | ||
| return_kind='output-tool-not-executed', | ||
| ) | ||
| yield _messages.FunctionToolResultEvent(part) | ||
|
|
||
|
|
@@ -902,6 +914,7 @@ async def process_tool_calls( # noqa: C901 | |
| tool_name=call.tool_name, | ||
| content='Final result processed.', | ||
| tool_call_id=call.tool_call_id, | ||
| return_kind='final-result-processed', | ||
| ) | ||
| output_parts.append(part) | ||
| final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) | ||
|
|
@@ -915,6 +928,7 @@ async def process_tool_calls( # noqa: C901 | |
| tool_name=call.tool_name, | ||
| content='Tool not executed - a final result was already processed.', | ||
| tool_call_id=call.tool_call_id, | ||
| return_kind='function-tool-not-executed', | ||
| ) | ||
| ) | ||
| else: | ||
|
|
@@ -973,6 +987,7 @@ async def process_tool_calls( # noqa: C901 | |
| tool_name=call.tool_name, | ||
| content='Tool not executed - a final result was already processed.', | ||
| tool_call_id=call.tool_call_id, | ||
| return_kind='function-tool-not-executed', | ||
| ) | ||
| ) | ||
| elif calls: | ||
|
|
@@ -1129,6 +1144,7 @@ async def _call_tool( | |
| tool_name=tool_call.tool_name, | ||
| content=tool_call_result.message, | ||
adtyavrdhn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| tool_call_id=tool_call.tool_call_id, | ||
| return_kind='tool-denied', | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not too opposed to having this new field, but I wonder if it's strictly necessary. Since we build the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we could, but I like the idea of having this kind in the messages: it increases the visibility of the ToolReturnPart's context. |
||
| ), None | ||
| elif isinstance(tool_call_result, exceptions.ModelRetry): | ||
| m = _messages.RetryPromptPart( | ||
|
|
@@ -1191,6 +1207,7 @@ async def _call_tool( | |
| tool_call_id=tool_call.tool_call_id, | ||
| content=tool_return.return_value, # type: ignore | ||
| metadata=tool_return.metadata, | ||
| return_kind='tool-executed', | ||
| ) | ||
|
|
||
| return return_part, tool_return.content or None | ||
|
|
@@ -1361,3 +1378,11 @@ def _clean_message_history(messages: list[_messages.ModelMessage]) -> list[_mess | |
| else: | ||
| clean_messages.append(message) | ||
| return clean_messages | ||
|
|
||
|
|
||
| def _apply_prompt_templates( | ||
| messages: list[_messages.ModelMessage], prompt_templates: _messages.PromptTemplates, ctx: RunContext[Any] | ||
| ): | ||
| for msg in messages: | ||
| for msg_part in msg.parts: | ||
| prompt_templates.apply_template(msg_part, ctx) | ||
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
We should not modify the `ModelRequest`s or parts in the history in place, so please have this method build new objects and a new list. We should make sure the modified final `ModelRequest` shows up in `result.all_messages()` etc., so we need `ctx.state.message_history[:] = message_history` as above. There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
This only applies to message history passed in by the user — that's not ours to edit. We can of course modify
`self.request` and other `ModelRequest`s we build ourselves