Skip to content

Commit 69997cd

Browse files
MarlzRanaxuanyang15
authored andcommitted
fix: oauth refresh not triggered on token expiry
Merge #3767 Co-authored-by: Xuan Yang <[email protected]> COPYBARA_INTEGRATE_REVIEW=#3767 from MarlzRana:marlzrana/fix-oauth-refresh-not-triggered-on-token-expiry 2dae391 PiperOrigin-RevId: 843756363
1 parent 5c4bae7 commit 69997cd

File tree

5 files changed

+75
-59
lines changed

5 files changed

+75
-59
lines changed

src/google/adk/auth/auth_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ async def exchange_auth_token(
4848
self,
4949
) -> AuthCredential:
5050
exchanger = OAuth2CredentialExchanger()
51-
return await exchanger.exchange(
51+
exchange_result = await exchanger.exchange(
5252
self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme
5353
)
54+
return exchange_result.credential
5455

5556
async def parse_and_store_auth_response(self, state: State) -> None:
5657

src/google/adk/auth/credential_manager.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from .auth_schemes import OpenIdConnectWithConfig
3030
from .auth_tool import AuthConfig
3131
from .exchanger.base_credential_exchanger import BaseCredentialExchanger
32+
from .exchanger.base_credential_exchanger import ExchangeResult
3233
from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry
3334
from .oauth2_discovery import OAuth2DiscoveryManager
3435
from .refresher.credential_refresher_registry import CredentialRefresherRegistry
@@ -214,15 +215,17 @@ async def _exchange_credential(
214215
return credential, False
215216

216217
if isinstance(exchanger, ServiceAccountCredentialExchanger):
217-
exchanged_credential = exchanger.exchange_credential(
218-
self._auth_config.auth_scheme, credential
219-
)
220-
else:
221-
exchanged_credential = await exchanger.exchange(
222-
credential, self._auth_config.auth_scheme
218+
return (
219+
exchanger.exchange_credential(
220+
self._auth_config.auth_scheme, credential
221+
),
222+
True,
223223
)
224224

225-
return exchanged_credential, True
225+
exchange_result = await exchanger.exchange(
226+
credential, self._auth_config.auth_scheme
227+
)
228+
return exchange_result.credential, exchange_result.was_exchanged
226229

227230
async def _refresh_credential(
228231
self, credential: AuthCredential

src/google/adk/auth/exchanger/base_credential_exchanger.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
import abc
20+
from typing import NamedTuple
2021
from typing import Optional
2122

2223
from ...utils.feature_decorator import experimental
@@ -28,6 +29,11 @@ class CredentialExchangeError(Exception):
2829
"""Base exception for credential exchange errors."""
2930

3031

32+
class ExchangeResult(NamedTuple):
33+
credential: AuthCredential
34+
was_exchanged: bool
35+
36+
3137
@experimental
3238
class BaseCredentialExchanger(abc.ABC):
3339
"""Base interface for credential exchangers.
@@ -41,15 +47,17 @@ async def exchange(
4147
self,
4248
auth_credential: AuthCredential,
4349
auth_scheme: Optional[AuthScheme] = None,
44-
) -> AuthCredential:
50+
) -> ExchangeResult:
4551
"""Exchange credential if needed.
4652
4753
Args:
4854
auth_credential: The credential to exchange.
49-
auth_scheme: The authentication scheme (optional, some exchangers don't need it).
55+
auth_scheme: The authentication scheme (optional, some exchangers don't
56+
need it).
5057
5158
Returns:
52-
The exchanged credential.
59+
An ExchangeResult object containing the exchanged credential and a
60+
boolean indicating whether the credential was exchanged.
5361
5462
Raises:
5563
CredentialExchangeError: If credential exchange fails.

src/google/adk/auth/exchanger/oauth2_credential_exchanger.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
from .base_credential_exchanger import BaseCredentialExchanger
3333
from .base_credential_exchanger import CredentialExchangeError
34+
from .base_credential_exchanger import ExchangeResult
3435

3536
try:
3637
from authlib.integrations.requests_client import OAuth2Session
@@ -51,7 +52,7 @@ async def exchange(
5152
self,
5253
auth_credential: AuthCredential,
5354
auth_scheme: Optional[AuthScheme] = None,
54-
) -> AuthCredential:
55+
) -> ExchangeResult:
5556
"""Exchange OAuth2 credential from authorization response.
5657
5758
if credential exchange failed, the original credential will be returned.
@@ -61,7 +62,8 @@ async def exchange(
6162
auth_scheme: The OAuth2 authentication scheme.
6263
6364
Returns:
64-
The exchanged credential with access token.
65+
An ExchangeResult object containing the exchanged credential and a
66+
boolean indicating whether the credential was exchanged.
6567
6668
Raises:
6769
CredentialExchangeError: If auth_scheme is missing.
@@ -79,10 +81,10 @@ async def exchange(
7981
logger.warning(
8082
"authlib is not available, skipping OAuth2 credential exchange."
8183
)
82-
return auth_credential
84+
return ExchangeResult(auth_credential, False)
8385

8486
if auth_credential.oauth2 and auth_credential.oauth2.access_token:
85-
return auth_credential
87+
return ExchangeResult(auth_credential, False)
8688

8789
# Determine grant type from auth_scheme
8890
grant_type = self._determine_grant_type(auth_scheme)
@@ -97,7 +99,7 @@ async def exchange(
9799
)
98100
else:
99101
logger.warning("Unsupported OAuth2 grant type: %s", grant_type)
100-
return auth_credential
102+
return ExchangeResult(auth_credential, False)
101103

102104
def _determine_grant_type(
103105
self, auth_scheme: AuthScheme
@@ -129,22 +131,23 @@ async def _exchange_client_credentials(
129131
self,
130132
auth_credential: AuthCredential,
131133
auth_scheme: AuthScheme,
132-
) -> AuthCredential:
134+
) -> ExchangeResult:
133135
"""Exchange client credentials for access token.
134136
135137
Args:
136138
auth_credential: The OAuth2 credential to exchange.
137139
auth_scheme: The OAuth2 authentication scheme.
138140
139141
Returns:
140-
The credential with access token.
142+
An ExchangeResult object containing the exchanged credential and a
143+
boolean indicating whether the credential was exchanged.
141144
"""
142145
client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential)
143146
if not client:
144147
logger.warning(
145148
"Could not create OAuth2 session for client credentials exchange"
146149
)
147-
return auth_credential
150+
return ExchangeResult(auth_credential, False)
148151

149152
try:
150153
tokens = client.fetch_token(
@@ -155,13 +158,13 @@ async def _exchange_client_credentials(
155158
logger.debug("Successfully exchanged client credentials for access token")
156159
except Exception as e:
157160
logger.error("Failed to exchange client credentials: %s", e)
158-
return auth_credential
161+
return ExchangeResult(auth_credential, False)
159162

160-
return auth_credential
163+
return ExchangeResult(auth_credential, True)
161164

162165
def _normalize_auth_uri(self, auth_uri: str | None) -> str | None:
163-
# Authlib currently used a simplified token check by simply scanning hash existence,
164-
# yet itself might sometimes add extraneous hashes.
166+
# Authlib currently used a simplified token check by simply scanning hash
167+
# existence, yet itself might sometimes add extraneous hashes.
165168
# Drop trailing empty hash if seen.
166169
if auth_uri and auth_uri.endswith("#"):
167170
return auth_uri[:-1]
@@ -171,22 +174,23 @@ async def _exchange_authorization_code(
171174
self,
172175
auth_credential: AuthCredential,
173176
auth_scheme: AuthScheme,
174-
) -> AuthCredential:
177+
) -> ExchangeResult:
175178
"""Exchange authorization code for access token.
176179
177180
Args:
178181
auth_credential: The OAuth2 credential to exchange.
179182
auth_scheme: The OAuth2 authentication scheme.
180183
181184
Returns:
182-
The credential with access token.
185+
An ExchangeResult object containing the exchanged credential and a
186+
boolean indicating whether the credential was exchanged.
183187
"""
184188
client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential)
185189
if not client:
186190
logger.warning(
187191
"Could not create OAuth2 session for authorization code exchange"
188192
)
189-
return auth_credential
193+
return ExchangeResult(auth_credential, False)
190194

191195
try:
192196
tokens = client.fetch_token(
@@ -202,6 +206,6 @@ async def _exchange_authorization_code(
202206
logger.debug("Successfully exchanged authorization code for access token")
203207
except Exception as e:
204208
logger.error("Failed to exchange authorization code: %s", e)
205-
return auth_credential
209+
return ExchangeResult(auth_credential, False)
206210

207-
return auth_credential
211+
return ExchangeResult(auth_credential, True)

tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
class TestOAuth2CredentialExchanger:
3434
"""Test suite for OAuth2CredentialExchanger."""
3535

36-
@pytest.mark.asyncio
3736
async def test_exchange_with_existing_token(self):
3837
"""Test exchange method when access token already exists."""
3938
scheme = OpenIdConnectWithConfig(
@@ -55,14 +54,14 @@ async def test_exchange_with_existing_token(self):
5554
)
5655

5756
exchanger = OAuth2CredentialExchanger()
58-
result = await exchanger.exchange(credential, scheme)
57+
exchange_result = await exchanger.exchange(credential, scheme)
5958

6059
# Should return the same credential since access token already exists
61-
assert result == credential
62-
assert result.oauth2.access_token == "existing_token"
60+
assert exchange_result.credential == credential
61+
assert exchange_result.credential.oauth2.access_token == "existing_token"
62+
assert not exchange_result.was_exchanged
6363

6464
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
65-
@pytest.mark.asyncio
6665
async def test_exchange_success(self, mock_oauth2_session):
6766
"""Test successful token exchange."""
6867
# Setup mock
@@ -96,14 +95,16 @@ async def test_exchange_success(self, mock_oauth2_session):
9695
)
9796

9897
exchanger = OAuth2CredentialExchanger()
99-
result = await exchanger.exchange(credential, scheme)
98+
exchange_result = await exchanger.exchange(credential, scheme)
10099

101100
# Verify token exchange was successful
102-
assert result.oauth2.access_token == "new_access_token"
103-
assert result.oauth2.refresh_token == "new_refresh_token"
101+
assert exchange_result.credential.oauth2.access_token == "new_access_token"
102+
assert (
103+
exchange_result.credential.oauth2.refresh_token == "new_refresh_token"
104+
)
105+
assert exchange_result.was_exchanged
104106
mock_client.fetch_token.assert_called_once()
105107

106-
@pytest.mark.asyncio
107108
async def test_exchange_missing_auth_scheme(self):
108109
"""Test exchange with missing auth_scheme raises ValueError."""
109110
credential = AuthCredential(
@@ -122,7 +123,6 @@ async def test_exchange_missing_auth_scheme(self):
122123
assert "auth_scheme is required" in str(e)
123124

124125
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
125-
@pytest.mark.asyncio
126126
async def test_exchange_no_session(self, mock_oauth2_session):
127127
"""Test exchange when OAuth2Session cannot be created."""
128128
# Mock to return None for create_oauth2_session
@@ -146,14 +146,14 @@ async def test_exchange_no_session(self, mock_oauth2_session):
146146
)
147147

148148
exchanger = OAuth2CredentialExchanger()
149-
result = await exchanger.exchange(credential, scheme)
149+
exchange_result = await exchanger.exchange(credential, scheme)
150150

151151
# Should return original credential when session creation fails
152-
assert result == credential
153-
assert result.oauth2.access_token is None
152+
assert exchange_result.credential == credential
153+
assert exchange_result.credential.oauth2.access_token is None
154+
assert not exchange_result.was_exchanged
154155

155156
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
156-
@pytest.mark.asyncio
157157
async def test_exchange_fetch_token_failure(self, mock_oauth2_session):
158158
"""Test exchange when fetch_token fails."""
159159
# Setup mock to raise exception during fetch_token
@@ -181,14 +181,14 @@ async def test_exchange_fetch_token_failure(self, mock_oauth2_session):
181181
)
182182

183183
exchanger = OAuth2CredentialExchanger()
184-
result = await exchanger.exchange(credential, scheme)
184+
exchange_result = await exchanger.exchange(credential, scheme)
185185

186186
# Should return original credential when fetch_token fails
187-
assert result == credential
188-
assert result.oauth2.access_token is None
187+
assert exchange_result.credential == credential
188+
assert exchange_result.credential.oauth2.access_token is None
189+
assert not exchange_result.was_exchanged
189190
mock_client.fetch_token.assert_called_once()
190191

191-
@pytest.mark.asyncio
192192
async def test_exchange_authlib_not_available(self):
193193
"""Test exchange when authlib is not available."""
194194
scheme = OpenIdConnectWithConfig(
@@ -217,14 +217,14 @@ async def test_exchange_authlib_not_available(self):
217217
"google.adk.auth.exchanger.oauth2_credential_exchanger.AUTHLIB_AVAILABLE",
218218
False,
219219
):
220-
result = await exchanger.exchange(credential, scheme)
220+
exchange_result = await exchanger.exchange(credential, scheme)
221221

222222
# Should return original credential when authlib is not available
223-
assert result == credential
224-
assert result.oauth2.access_token is None
223+
assert exchange_result.credential == credential
224+
assert exchange_result.credential.oauth2.access_token is None
225+
assert not exchange_result.was_exchanged
225226

226227
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
227-
@pytest.mark.asyncio
228228
async def test_exchange_client_credentials_success(self, mock_oauth2_session):
229229
"""Test successful client credentials exchange."""
230230
# Setup mock
@@ -255,17 +255,19 @@ async def test_exchange_client_credentials_success(self, mock_oauth2_session):
255255
)
256256

257257
exchanger = OAuth2CredentialExchanger()
258-
result = await exchanger.exchange(credential, scheme)
258+
exchange_result = await exchanger.exchange(credential, scheme)
259259

260260
# Verify client credentials exchange was successful
261-
assert result.oauth2.access_token == "client_access_token"
261+
assert (
262+
exchange_result.credential.oauth2.access_token == "client_access_token"
263+
)
264+
assert exchange_result.was_exchanged
262265
mock_client.fetch_token.assert_called_once_with(
263266
"https://example.com/token",
264267
grant_type="client_credentials",
265268
)
266269

267270
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
268-
@pytest.mark.asyncio
269271
async def test_exchange_client_credentials_failure(self, mock_oauth2_session):
270272
"""Test client credentials exchange failure."""
271273
# Setup mock to raise exception during fetch_token
@@ -292,15 +294,15 @@ async def test_exchange_client_credentials_failure(self, mock_oauth2_session):
292294
)
293295

294296
exchanger = OAuth2CredentialExchanger()
295-
result = await exchanger.exchange(credential, scheme)
297+
exchange_result = await exchanger.exchange(credential, scheme)
296298

297299
# Should return original credential when client credentials exchange fails
298-
assert result == credential
299-
assert result.oauth2.access_token is None
300+
assert exchange_result.credential == credential
301+
assert exchange_result.credential.oauth2.access_token is None
302+
assert not exchange_result.was_exchanged
300303
mock_client.fetch_token.assert_called_once()
301304

302305
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
303-
@pytest.mark.asyncio
304306
async def test_exchange_normalize_uri(self, mock_oauth2_session):
305307
"""Test exchange method normalizes auth_response_uri."""
306308
mock_client = Mock()
@@ -344,7 +346,6 @@ async def test_exchange_normalize_uri(self, mock_oauth2_session):
344346
client_id="test_client_id",
345347
)
346348

347-
@pytest.mark.asyncio
348349
async def test_determine_grant_type_client_credentials(self):
349350
"""Test grant type determination for client credentials."""
350351
flows = OAuthFlows(
@@ -361,7 +362,6 @@ async def test_determine_grant_type_client_credentials(self):
361362

362363
assert grant_type == OAuthGrantType.CLIENT_CREDENTIALS
363364

364-
@pytest.mark.asyncio
365365
async def test_determine_grant_type_openid_connect(self):
366366
"""Test grant type determination for OpenID Connect (defaults to auth code)."""
367367
scheme = OpenIdConnectWithConfig(

0 commit comments

Comments
 (0)