diff --git a/src/bub/core/model_runner.py b/src/bub/core/model_runner.py index cb2dc832..381aff4a 100644 --- a/src/bub/core/model_runner.py +++ b/src/bub/core/model_runner.py @@ -166,7 +166,8 @@ async def _chat(self, prompt: str) -> _ChatResult: try: async with asyncio.timeout(self._model_timeout_seconds): provider, _, _ = self._model.partition(":") - if provider.casefold() == "vertexai": + provider = provider.casefold() + if provider == "vertexai": output = await self._tape.tape.run_tools_async( prompt=prompt, system_prompt=system_prompt, @@ -174,7 +175,7 @@ async def _chat(self, prompt: str) -> _ChatResult: tools=self._tools, http_options={"headers": self.DEFAULT_HEADERS}, ) - else: + elif provider == "openrouter": output = await self._tape.tape.run_tools_async( prompt=prompt, system_prompt=system_prompt, @@ -182,6 +183,13 @@ async def _chat(self, prompt: str) -> _ChatResult: tools=self._tools, extra_headers=self.DEFAULT_HEADERS, ) + else: + output = await self._tape.tape.run_tools_async( + prompt=prompt, + system_prompt=system_prompt, + max_tokens=self._max_tokens, + tools=self._tools, + ) return _ChatResult.from_tool_auto(output) except TimeoutError: return _ChatResult( diff --git a/tests/test_model_runner.py b/tests/test_model_runner.py index 4a9ed0a4..ef611dfe 100644 --- a/tests/test_model_runner.py +++ b/tests/test_model_runner.py @@ -432,7 +432,30 @@ async def test_model_runner_maps_headers_for_vertexai() -> None: @pytest.mark.asyncio -async def test_model_runner_uses_extra_headers_for_unknown_provider() -> None: +async def test_model_runner_does_not_send_headers_for_gemini() -> None: + tape = FakeTapeService(FakeTapeImpl(outputs=[ToolAutoResult.text_result("assistant-only")])) + runner = ModelRunner( + tape=tape, # type: ignore[arg-type] + router=SingleStepRouter(), # type: ignore[arg-type] + tool_view=FakeToolView(), # type: ignore[arg-type] + tools=[], + list_skills=lambda: [], + model="gemini:test", + max_steps=1, + max_tokens=512, + model_timeout_seconds=90, + base_system_prompt="base", + get_workspace_system_prompt=lambda: "", + ) + + await runner.run("hi") + kwargs = tape.tape.call_kwargs[0] + assert "extra_headers" not in kwargs + assert "http_options" not in kwargs + + +@pytest.mark.asyncio +async def test_model_runner_does_not_send_headers_for_unknown_provider() -> None: tape = FakeTapeService(FakeTapeImpl(outputs=[ToolAutoResult.text_result("assistant-only")])) runner = ModelRunner( tape=tape, # type: ignore[arg-type] @@ -450,5 +473,5 @@ async def test_model_runner_uses_extra_headers_for_unknown_provider() -> None: await runner.run("hi") kwargs = tape.tape.call_kwargs[0] - assert kwargs.get("extra_headers") == ModelRunner.DEFAULT_HEADERS + assert "extra_headers" not in kwargs assert "http_options" not in kwargs