Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/bub/core/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,22 +166,30 @@ async def _chat(self, prompt: str) -> _ChatResult:
try:
async with asyncio.timeout(self._model_timeout_seconds):
provider, _, _ = self._model.partition(":")
if provider.casefold() == "vertexai":
provider = provider.casefold()
if provider == "vertexai":
output = await self._tape.tape.run_tools_async(
prompt=prompt,
system_prompt=system_prompt,
max_tokens=self._max_tokens,
tools=self._tools,
http_options={"headers": self.DEFAULT_HEADERS},
)
else:
elif provider == "openrouter":
output = await self._tape.tape.run_tools_async(
prompt=prompt,
system_prompt=system_prompt,
max_tokens=self._max_tokens,
tools=self._tools,
extra_headers=self.DEFAULT_HEADERS,
)
else:
output = await self._tape.tape.run_tools_async(
prompt=prompt,
system_prompt=system_prompt,
max_tokens=self._max_tokens,
tools=self._tools,
)
return _ChatResult.from_tool_auto(output)
except TimeoutError:
return _ChatResult(
Expand Down
27 changes: 25 additions & 2 deletions tests/test_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,30 @@ async def test_model_runner_maps_headers_for_vertexai() -> None:


@pytest.mark.asyncio
async def test_model_runner_uses_extra_headers_for_unknown_provider() -> None:
async def test_model_runner_does_not_send_headers_for_gemini() -> None:
tape = FakeTapeService(FakeTapeImpl(outputs=[ToolAutoResult.text_result("assistant-only")]))
runner = ModelRunner(
tape=tape, # type: ignore[arg-type]
router=SingleStepRouter(), # type: ignore[arg-type]
tool_view=FakeToolView(), # type: ignore[arg-type]
tools=[],
list_skills=lambda: [],
model="gemini:test",
max_steps=1,
max_tokens=512,
model_timeout_seconds=90,
base_system_prompt="base",
get_workspace_system_prompt=lambda: "",
)

await runner.run("hi")
kwargs = tape.tape.call_kwargs[0]
assert "extra_headers" not in kwargs
assert "http_options" not in kwargs


@pytest.mark.asyncio
async def test_model_runner_does_not_send_headers_for_unknown_provider() -> None:
tape = FakeTapeService(FakeTapeImpl(outputs=[ToolAutoResult.text_result("assistant-only")]))
runner = ModelRunner(
tape=tape, # type: ignore[arg-type]
Expand All @@ -450,5 +473,5 @@ async def test_model_runner_uses_extra_headers_for_unknown_provider() -> None:

await runner.run("hi")
kwargs = tape.tape.call_kwargs[0]
assert kwargs.get("extra_headers") == ModelRunner.DEFAULT_HEADERS
assert "extra_headers" not in kwargs
assert "http_options" not in kwargs