diff --git a/openeo/rest/auth/oidc.py b/openeo/rest/auth/oidc.py index daa184a93..7cd7df84f 100644 --- a/openeo/rest/auth/oidc.py +++ b/openeo/rest/auth/oidc.py @@ -256,6 +256,7 @@ def __init__( title: str = None, default_clients: Union[List[dict], None] = None, requests_session: Optional[requests.Session] = None, + authorization_parameters: Optional[dict] = None, ): # TODO: id and title are required in the openEO API spec. self.id = provider_id @@ -280,6 +281,7 @@ def __init__( self._scopes = {"openid"}.union(scopes or []).intersection(self._supported_scopes) log.debug(f"Scopes: provider supported {self._supported_scopes} & backend desired {scopes} -> {self._scopes}") self.default_clients = default_clients + self.authorization_parameters = authorization_parameters or {} @classmethod def from_dict(cls, data: dict) -> OidcProviderInfo: @@ -289,6 +291,7 @@ def from_dict(cls, data: dict) -> OidcProviderInfo: issuer=data["issuer"], scopes=data.get("scopes"), default_clients=data.get("default_clients"), + authorization_parameters=data.get("authorization_parameters"), ) def get_scopes_string(self, request_refresh_token: bool = False) -> str: @@ -563,6 +566,7 @@ def _get_auth_code(self, request_refresh_token: bool = False) -> AuthCodeResult: "nonce": nonce, "code_challenge": pkce.code_challenge, "code_challenge_method": pkce.code_challenge_method, + **self._client_info.provider.authorization_parameters, } ), ) @@ -855,6 +859,7 @@ def _get_verification_info(self, request_refresh_token: bool = False) -> Verific if self._pkce: post_data["code_challenge"] = self._pkce.code_challenge post_data["code_challenge_method"] = self._pkce.code_challenge_method + post_data.update(self._client_info.provider.authorization_parameters) resp = self._requests.post(url=self._device_code_url, data=post_data) if resp.status_code != 200: raise OidcException( diff --git a/tests/rest/auth/test_oidc.py b/tests/rest/auth/test_oidc.py index e1bec7028..16d546f45 100644 --- a/tests/rest/auth/test_oidc.py +++ b/tests/rest/auth/test_oidc.py @@ -1,6 +1,7 @@ import logging import re import time +import urllib.parse from io import BytesIO from queue import Queue @@ -818,3 +819,95 @@ def post_token(request, context): assert tokens.access_token == "6cce5-t0k3n" assert len(adapter.request_history) == 2 + + +class TestOidcProviderInfoAuthorizationParameters: + """Tests for the authorization_parameters flag introduced in openEO API >= 1.3.0""" + + def test_from_dict_with_authorization_parameters(self, requests_mock): + requests_mock.get("https://authit.test/.well-known/openid-configuration", json={"scopes_supported": ["openid"]}) + data = { + "id": "google", + "title": "Google", + "issuer": "https://authit.test", + "scopes": ["openid"], + "authorization_parameters": {"access_type": "offline", "prompt": "consent"}, + } + info = OidcProviderInfo.from_dict(data) + assert info.authorization_parameters == {"access_type": "offline", "prompt": "consent"} + + def test_from_dict_without_authorization_parameters(self, requests_mock): + requests_mock.get("https://authit.test/.well-known/openid-configuration", json={"scopes_supported": ["openid"]}) + data = { + "id": "egi", + "title": "EGI", + "issuer": "https://authit.test", + } + info = OidcProviderInfo.from_dict(data) + assert info.authorization_parameters == {} + + def test_device_code_request_includes_authorization_parameters(self, requests_mock): + """Checks whether the authorization_parameters ends up in the device code POST body.""" + oidc_issuer = "https://authit.test" + requests_mock.get( + f"{oidc_issuer}/.well-known/openid-configuration", + json={ + "scopes_supported": ["openid"], + "device_authorization_endpoint": f"{oidc_issuer}/device_code", + "token_endpoint": f"{oidc_issuer}/token", + }, + ) + device_code_mock = requests_mock.post( + f"{oidc_issuer}/device_code", + json={ + "device_code": "d3v1c3", + "user_code": "US3R", + "verification_uri": f"{oidc_issuer}/dc", + "interval": 5, + }, + ) + provider = OidcProviderInfo( + issuer=oidc_issuer, + authorization_parameters={"access_type": "offline", "prompt": "consent"}, + ) + authenticator = OidcDeviceAuthenticator( + client_info=OidcClientInfo(client_id="myclient", provider=provider, client_secret="s3cr3t"), + ) + authenticator._get_verification_info() + + assert device_code_mock.call_count == 1 + post_body = urllib.parse.parse_qs(device_code_mock.last_request.text) + assert post_body["client_id"] == ["myclient"] + assert post_body["access_type"] == ["offline"] + assert post_body["prompt"] == ["consent"] + + def test_device_code_request_without_authorization_parameters(self, requests_mock): + """Verify no extra params when authorization_parameters is empty.""" + oidc_issuer = "https://authit.test" + requests_mock.get( + f"{oidc_issuer}/.well-known/openid-configuration", + json={ + "scopes_supported": ["openid"], + "device_authorization_endpoint": f"{oidc_issuer}/device_code", + "token_endpoint": f"{oidc_issuer}/token", + }, + ) + device_code_mock = requests_mock.post( + f"{oidc_issuer}/device_code", + json={ + "device_code": "d3v1c3", + "user_code": "US3R", + "verification_uri": f"{oidc_issuer}/dc", + "interval": 5, + }, + ) + provider = OidcProviderInfo(issuer=oidc_issuer) + authenticator = OidcDeviceAuthenticator( + client_info=OidcClientInfo(client_id="myclient", provider=provider, client_secret="s3cr3t"), + ) + authenticator._get_verification_info() + + post_body = urllib.parse.parse_qs(device_code_mock.last_request.text) + assert "access_type" not in post_body + assert "prompt" not in post_body + assert set(post_body.keys()) == {"client_id", "scope"}