diff --git a/mp_api/client/_server_utils.py b/mp_api/client/_server_utils.py index be17aa56..0951a172 100644 --- a/mp_api/client/_server_utils.py +++ b/mp_api/client/_server_utils.py @@ -1,33 +1,62 @@ -"""Define utilities needed by the MP web server.""" +"""Define flask-dependent utilities for the web server.""" + from __future__ import annotations +from typing import TYPE_CHECKING + try: - import flask + from flask import has_request_context as _has_request_context + from flask import request except ImportError: - from mp_api.client.core.exceptions import MPRestError + _has_request_context = None # type: ignore[assignment] + request = None # type: ignore[assignment] + +from mp_api.client.core.utils import validate_api_key - raise MPRestError("`flask` must be installed to use server utilities.") +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Any -import requests -from mp_api.client import MPRester -from mp_api.client.core.utils import validate_api_key +def has_request_context() -> bool: + """Determine if the current context is a request. + + Returns: + -------- + bool : True if in a request context + False if flask is not installed or not in a request context. + """ + return _has_request_context is not None and _has_request_context() + + +def get_request_headers() -> dict[str, Any]: + """Get the headers if operating in a request context. + + Returns: + -------- + dict of str to Any + Empty dict if flask is not installed, or not in a request context. + Request headers otherwise. + """ + return request.headers if has_request_context() else {} -SESSION = requests.Session() +def is_dev_env( + localhosts: Sequence[str] = ("localhost:", "127.0.0.1:", "0.0.0.0:") +) -> bool: + """Determine if current env is local/developmental or production. -def is_localhost() -> bool: - """Determine if current env is local or production. + Args: + localhosts (Sequence of str) : A set of host prefixes for checking + if the current environment is locally deployed. Returns: bool: True if the environment is locally hosted. """ return ( True - if not flask.has_request_context() - else flask.request.headers.get("Host", "").startswith( - ("localhost:", "127.0.0.1:", "0.0.0.0:") - ) + if not has_request_context() + else get_request_headers().get("Host", "").startswith(localhosts) ) @@ -37,7 +66,7 @@ def get_consumer() -> dict[str, str]: Returns: dict of str to str, the headers associated with the consumer """ - if not flask.has_request_context(): + if not has_request_context(): return {} names = [ @@ -48,7 +77,7 @@ def get_consumer() -> dict[str, str]: "X-Authenticated-Groups", # groups this user belongs to "X-Consumer-Groups", # same as X-Authenticated-Groups ] - headers = flask.request.headers + headers = get_request_headers() return {name: headers[name] for name in names if headers.get(name) is not None} @@ -65,39 +94,23 @@ def is_logged_in_user(consumer: dict[str, str] | None = None) -> bool: return bool(not c.get("X-Anonymous-Consumer") and c.get("X-Consumer-Id")) -def get_user_api_key(consumer: dict[str, str] | None = None) -> str | None: +def get_user_api_key( + api_key: str | None = None, consumer: dict[str, str] | None = None +) -> str | None: """Get the api key that belongs to the current user. If running on localhost, api key is obtained from the environment variable MP_API_KEY. Args: + api_key (str or None) : User API key consumer (dict of str to str, or None): Headers associated with the consumer Returns: str, the API key, or None if no API key could be identified. """ - c = consumer or get_consumer() - - if is_localhost(): - return validate_api_key() - elif is_logged_in_user(c): + if is_dev_env(): + return validate_api_key(api_key=api_key) + elif is_logged_in_user(c := consumer or get_consumer()): return c.get("X-Consumer-Custom-Id") return None - - -def get_rester(**kwargs) -> MPRester: - """Create MPRester with headers set for localhost and production compatibility. - - Args: - **kwargs : kwargs to pass to MPRester - - Returns: - MPRester - """ - if is_localhost(): - dev_api_key = get_user_api_key() - SESSION.headers["x-api-key"] = dev_api_key or "" - return MPRester(api_key=dev_api_key, session=SESSION, **kwargs) - - return MPRester(headers=get_consumer(), session=SESSION, **kwargs) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index c192ecc9..e2127b22 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -57,13 +57,6 @@ validate_ids, ) -try: - import flask - - _flask_is_installed = True -except ImportError: - _flask_is_installed = False - if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator from typing import Any @@ -1177,17 +1170,13 @@ def _submit_request_and_process( Returns: Tuple with data and total number of docs in matching the query in the database. """ - headers = None - if _flask_is_installed and flask.has_request_context(): - headers = flask.request.headers - try: response = self.session.get( url=url, verify=verify, params=params, timeout=timeout, - headers=headers if headers else self.headers, + headers=self.headers, ) except requests.exceptions.ConnectTimeout: raise MPRestError( diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index b7a4cebd..dae0284f 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -72,6 +72,11 @@ class MAPIClientSettings(BaseSettings): description="Angle tolerance for structure matching in degrees.", ) + LOG_FILE: Path = Field( + Path("~/.mprester.log.yaml").expanduser(), + description="Path for storing last accessed database version.", + ) + LOCAL_DATASET_CACHE: Path = Field( Path("~/mp_datasets").expanduser(), description="Target directory for downloading full datasets", diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index c04aac7f..146a4fe1 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -21,6 +21,7 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from requests import Session, get +from mp_api.client._server_utils import get_consumer, get_user_api_key, is_dev_env from mp_api.client.core import BaseRester from mp_api.client.core._oxygen_evolution import OxygenEvolution from mp_api.client.core.exceptions import ( @@ -32,7 +33,6 @@ from mp_api.client.core.utils import ( LazyImport, load_json, - validate_api_key, validate_endpoint, validate_ids, ) @@ -141,16 +141,18 @@ def __init__( force_renew: Option to overwrite existing local dataset **kwargs: access to legacy kwargs that may be in the process of being deprecated """ - self.api_key = validate_api_key(api_key) + self.api_key = get_user_api_key(api_key=api_key) self.endpoint = validate_endpoint(endpoint) - self.headers = headers or {} + self.headers = headers or get_consumer() self.session = session or BaseRester._create_session( api_key=self.api_key, include_user_agent=include_user_agent, headers=self.headers, ) + if is_dev_env(): + self.session.headers["x-api-key"] = self.api_key or "" self._include_user_agent = include_user_agent self.use_document_model = use_document_model self.mute_progress_bars = mute_progress_bars @@ -209,7 +211,7 @@ def __init__( ) if notify_db_version: - raise NotImplementedError("This has not yet been implemented.") + self._db_version_check() # Dynamically set rester attributes. # First, materials and molecules top level resters are set. @@ -296,6 +298,10 @@ def __dir__(self): + [r.split("/", 1)[0] for r in TOP_LEVEL_RESTERS if not r.startswith("_")] ) + def __repr__(self) -> str: + db_version = self.get_database_version() + return f"MPRester({'v' + db_version if db_version else 'unknown version'})" + def get_task_ids_associated_with_material_id( self, material_id: str, calc_types: list[CalcType] | None = None ) -> list[str]: @@ -367,7 +373,7 @@ def get_database_version(self) -> str | None: where "_DD" may be optional. An additional numerical suffix might be added if multiple releases happen on the same day. - Returns: database version as a string + Returns: database version as a string if accessible, None otherwise """ if (get_resp := get(url=self.endpoint + "heartbeat")).status_code == 403: _emit_status_warning() @@ -1636,3 +1642,31 @@ def get_oxygen_evolution( phase_diagram, unique_composition, ) + + def _db_version_check(self) -> None: + """Check if the database version has drifted.""" + import yaml # type: ignore[import-untyped] + + db_version = self.get_database_version() + old_db_version = None + if MAPI_CLIENT_SETTINGS.LOG_FILE.exists(): + old_db_version = ( + yaml.safe_load(MAPI_CLIENT_SETTINGS.LOG_FILE.read_text()) or {} + ).get("MAPI_DB_VERSION", None) + + # Handle legacy pymatgen behavior + if not isinstance(old_db_version, str): + old_db_version = None + + if old_db_version != db_version: + MAPI_CLIENT_SETTINGS.LOG_FILE.write_text( + yaml.safe_dump({"MAPI_DB_VERSION": db_version}) + ) + + if old_db_version: + warnings.warn( + "Materials Project database version has changed " + f"from v{old_db_version} to v{db_version}.", + category=MPRestWarning, + stacklevel=2, + ) diff --git a/mp_api/client/routes/_server.py b/mp_api/client/routes/_server.py index fec9d2f6..4bced809 100644 --- a/mp_api/client/routes/_server.py +++ b/mp_api/client/routes/_server.py @@ -6,12 +6,13 @@ from emmet.core._general_store import GeneralStoreDoc from emmet.core._messages import MessagesDoc, MessageType -from emmet.core._user_settings import UserSettingsDoc +from emmet.core._user_settings import UserSettings, UserSettingsDoc from mp_api.client.core import BaseRester if TYPE_CHECKING: from datetime import datetime + from typing import Any class GeneralStoreRester(BaseRester): # pragma: no cover @@ -133,7 +134,9 @@ class UserSettingsRester(BaseRester): # pragma: no cover primary_key = "consumer_id" use_document_model = False - def create_user_settings(self, consumer_id, settings): + def create_user_settings( + self, consumer_id: str, settings: dict[str, Any] + ) -> dict[str, Any]: """Create user settings. Args: @@ -143,74 +146,76 @@ def create_user_settings(self, consumer_id, settings): Returns: Dictionary with consumer_id and write status. """ - return self._post_resource( + return self._post_resource( # type: ignore[return-value] body=settings, params={"consumer_id": consumer_id} ).get("data") - def patch_user_settings(self, consumer_id, settings): # pragma: no cover + def patch_user_settings( + self, consumer_id: str, settings: dict[str, Any] + ) -> UserSettingsDoc: """Patch user settings. Args: - consumer_id: Consumer ID for the user + consumer_id (str): Consumer ID for the user settings: Dictionary with user settings Returns: - Dictionary with consumer_id and write status. + UserSettingsDoc with consumer_id and write status. Raises: MPRestError. """ - body = dict() - valid_fields = [ - "institution", - "sector", - "job_role", - "is_email_subscribed", - "agreed_terms", - "message_last_read", - ] - for key in settings: - if key not in valid_fields: - raise ValueError( - f"Invalid setting key {key}. Must be one of {valid_fields}" - ) - body[f"settings.{key}"] = settings[key] - - return self._patch_resource(body=body, params={"consumer_id": consumer_id}).get( - "data" - ) + if ( + len( + invalid_keys := [ + key for key in settings if key not in UserSettings.model_fields + ] + ) + > 0 + ): + raise ValueError( + f"Invalid setting key(s): {', '.join(invalid_keys)}. " + f"Valid keys: {', '.join(UserSettings.model_fields)}" + ) + + return self._patch_resource( # type: ignore[return-value] + body={f"settings.{key}": v for key, v in settings.items()}, + params={"consumer_id": consumer_id}, + ).get("data") - def patch_user_time_settings(self, consumer_id, time): # pragma: no cover + def patch_user_time_settings( + self, consumer_id: str, time: datetime + ) -> UserSettingsDoc: """Set user settings last_read_message field. Args: - consumer_id: Consumer ID for the user - time: utc datetime object for when the user last see messages + consumer_id (str): Consumer ID for the user + time (datetime): UTC datetime object for when the user last see messages Returns: - Dictionary with consumer_id and write status. - + UserSettingsDoc Raises: MPRestError. """ - return self._patch_resource( + return self._patch_resource( # type: ignore[return-value] body={"settings.message_last_read": time.isoformat()}, params={"consumer_id": consumer_id}, ).get("data") - def get_user_settings(self, consumer_id, fields): # pragma: no cover + def get_user_settings( + self, consumer_id: str, fields: list[str] + ) -> list[UserSettingsDoc]: """Get user settings. Args: - consumer_id: Consumer ID for the user - fields: List of fields to project + consumer_id (str): Consumer ID for the user + fields (list of str): List of fields to project Returns: - Dictionary with consumer_id and settings. - + list of UserSettingsDoc, with consumer_id and settings. Raises: MPRestError. """ - return self._query_resource( + return self._query_resource( # type: ignore[return-value] suburl=f"{consumer_id}", fields=fields, num_chunks=1, chunk_size=1 ).get("data") diff --git a/tests/client/core/test_utils.py b/tests/client/core/test_utils.py index c8916a3c..cf5c9f59 100644 --- a/tests/client/core/test_utils.py +++ b/tests/client/core/test_utils.py @@ -142,7 +142,7 @@ def test_api_key_validation(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(pymatgen.core, "SETTINGS", non_api_key_settings) with pytest.raises(MPRestError, match="32 characters"): - validate_api_key("invalid_key") + validate_api_key(api_key="invalid_key") with pytest.warns(MPRestWarning, match="No API key found"): validate_api_key() @@ -150,7 +150,7 @@ def test_api_key_validation(monkeypatch: pytest.MonkeyPatch): junk_api_key = "a" * 32 monkeypatch.setenv("MP_API_KEY", junk_api_key) assert validate_api_key() == junk_api_key - assert validate_api_key(junk_api_key) == junk_api_key + assert validate_api_key(api_key=junk_api_key) == junk_api_key other_junk_api_key = "b" * 32 monkeypatch.setattr( diff --git a/tests/client/test_mprester.py b/tests/client/test_mprester.py index 6b89577f..38efd6e0 100644 --- a/tests/client/test_mprester.py +++ b/tests/client/test_mprester.py @@ -686,9 +686,25 @@ def test_nomad_integration(self, mpr): for url in nomad_urls ) - def test_warnings_exceptions(self, monkeypatch: pytest.MonkeyPatch): + def test_db_warning(self, monkeypatch: pytest.MonkeyPatch): + from pathlib import Path + import yaml from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS + with NamedTemporaryFile(suffix=".yaml") as tmp_log: + monkeypatch.setattr(MAPI_CLIENT_SETTINGS, "LOG_FILE", Path(tmp_log.name)) + + with MPRester(notify_db_version=True) as mpr: + db_version = mpr.get_database_version() + + parsed_db_ver = yaml.safe_load(Path(tmp_log.name).read_text()).get( + "MAPI_DB_VERSION" + ) + assert parsed_db_ver == db_version + assert isinstance(parsed_db_ver, str) + + def test_warnings_exceptions(self): + # Generic warnings/exceptions tests, nothji with pytest.warns(MPRestWarning, match="Ignoring `monty_decode`"): MPRester(monty_decode=False) @@ -710,6 +726,10 @@ def test_warnings_exceptions(self, monkeypatch: pytest.MonkeyPatch): ): getattr(mpr, attr, None) + def test_min_emmet_warning(self, monkeypatch: pytest.MonkeyPatch): + from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS + + with MPRester() as mpr: emmet_ver = mpr.get_emmet_version(mpr.endpoint) monkeypatch.setattr( MAPI_CLIENT_SETTINGS, "MIN_EMMET_VERSION", f"{emmet_ver.major + 1}.0.0"