diff --git a/.github/workflows/_test-integrations.yml b/.github/workflows/_test-integrations.yml
index 8b8bfa3c..70b7b0e9 100644
--- a/.github/workflows/_test-integrations.yml
+++ b/.github/workflows/_test-integrations.yml
@@ -49,6 +49,7 @@ jobs:
           MINDEE_V2_API_KEY: ${{ secrets.MINDEE_V2_SE_TESTS_API_KEY }}
           MINDEE_V2_FINDOC_MODEL_ID: ${{ secrets.MINDEE_V2_SE_TESTS_FINDOC_MODEL_ID }}
           MINDEE_V2_SE_TESTS_BLANK_PDF_URL: ${{ secrets.MINDEE_V2_SE_TESTS_BLANK_PDF_URL }}
+          MINDEE_V2_SE_TESTS_SPLIT_MODEL_ID: ${{ secrets.MINDEE_V2_SE_TESTS_SPLIT_MODEL_ID }}
         run: |
           pytest --cov mindee -m integration
diff --git a/mindee/client.py b/mindee/client.py
index 6b8d3ba1..5377314c 100644
--- a/mindee/client.py
+++ b/mindee/client.py
@@ -353,7 +353,7 @@ def enqueue_and_parse(  # pylint: disable=too-many-locals
         if poll_results.job.status == "failed":
             raise MindeeError("Parsing failed for job {poll_results.job.id}")
         logger.debug(
-            "Polling server for parsing result with job id: %s", queue_result.job.id
+            "Polling server for inference result with job id: %s", queue_result.job.id
         )
         retry_counter += 1
         sleep(delay_sec)
diff --git a/mindee/client_v2.py b/mindee/client_v2.py
index 6819b2cf..3a7ad300 100644
--- a/mindee/client_v2.py
+++ b/mindee/client_v2.py
@@ -1,10 +1,11 @@
+import warnings
 from time import sleep
-from typing import Optional, Union
+from typing import Optional, Union, Type, TypeVar

 from mindee.client_mixin import ClientMixin
 from mindee.error.mindee_error import MindeeError
 from mindee.error.mindee_http_error_v2 import handle_error_v2
-from mindee.input import UrlInputSource
+from mindee.input import UrlInputSource, BaseParameters
 from mindee.input.inference_parameters import InferenceParameters
 from mindee.input.polling_options import PollingOptions
 from mindee.input.sources.local_input_source import LocalInputSource
@@ -15,9 +16,12 @@
     is_valid_post_response,
 )
 from mindee.parsing.v2.common_response import CommonStatus
+from mindee.v2.parsing.inference.base_response import BaseResponse
 from mindee.parsing.v2.inference_response import InferenceResponse
 from mindee.parsing.v2.job_response import JobResponse

+TypeBaseInferenceResponse = TypeVar("TypeBaseInferenceResponse", bound=BaseResponse)
+

 class ClientV2(ClientMixin):
     """
@@ -41,20 +45,35 @@ def __init__(self, api_key: Optional[str] = None) -> None:
     def enqueue_inference(
         self,
         input_source: Union[LocalInputSource, UrlInputSource],
-        params: InferenceParameters,
+        params: BaseParameters,
+        disable_redundant_warnings: bool = False,
+    ) -> JobResponse:
+        """[Deprecated] Use `enqueue` instead."""
+        if not disable_redundant_warnings:
+            warnings.warn(
+                "enqueue_inference is deprecated; use enqueue instead",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        return self.enqueue(input_source, params)
+
+    def enqueue(
+        self,
+        input_source: Union[LocalInputSource, UrlInputSource],
+        params: BaseParameters,
     ) -> JobResponse:
         """
         Enqueues a document to a given model.

         :param input_source: The document/source file to use. Can be local or remote.
         :param params: Parameters to set when sending a file.
+        :return: A valid job response.
""" logger.debug("Enqueuing inference using model: %s", params.model_id) - response = self.mindee_api.req_post_inference_enqueue( - input_source=input_source, params=params + input_source=input_source, params=params, slug=params.get_enqueue_slug() ) dict_response = response.json() @@ -79,34 +98,57 @@ def get_job(self, job_id: str) -> JobResponse: dict_response = response.json() return JobResponse(dict_response) - def get_inference(self, inference_id: str) -> InferenceResponse: + def get_inference( + self, + inference_id: str, + response_type: Type[BaseResponse] = InferenceResponse, + disable_redundant_warnings: bool = False, + ) -> BaseResponse: + """[Deprecated] Use `get_result` instead.""" + if not disable_redundant_warnings: + warnings.warn( + "get_inference is deprecated; use get_result instead", + DeprecationWarning, + stacklevel=2, + ) + return self.get_result(inference_id, response_type) + + def get_result( + self, + inference_id: str, + response_type: Type[BaseResponse] = InferenceResponse, + ) -> BaseResponse: """ Get the result of an inference that was previously enqueued. The inference will only be available after it has finished processing. :param inference_id: UUID of the inference to retrieve. + :param response_type: Class of the product to instantiate. :return: An inference response. """ logger.debug("Fetching inference: %s", inference_id) - response = self.mindee_api.req_get_inference(inference_id) + response = self.mindee_api.req_get_inference( + inference_id, response_type.get_result_slug() + ) if not is_valid_get_response(response): handle_error_v2(response.json()) dict_response = response.json() - return InferenceResponse(dict_response) + return response_type(dict_response) - def enqueue_and_get_inference( + def _enqueue_and_get( self, input_source: Union[LocalInputSource, UrlInputSource], - params: InferenceParameters, - ) -> InferenceResponse: + params: BaseParameters, + response_type: Optional[Type[BaseResponse]] = InferenceResponse, + ) -> BaseResponse: """ Enqueues to an asynchronous endpoint and automatically polls for a response. :param input_source: The document/source file to use. Can be local or remote. - :param params: Parameters to set when sending a file. + :param response_type: The product class to use for the response object. :return: A valid inference response. 
""" @@ -117,9 +159,9 @@ def enqueue_and_get_inference( params.polling_options.delay_sec, params.polling_options.max_retries, ) - enqueue_response = self.enqueue_inference(input_source, params) + enqueue_response = self.enqueue_inference(input_source, params, True) logger.debug( - "Successfully enqueued inference with job id: %s", enqueue_response.job.id + "Successfully enqueued document with job id: %s", enqueue_response.job.id ) sleep(params.polling_options.initial_delay_sec) try_counter = 0 @@ -134,8 +176,51 @@ def enqueue_and_get_inference( f"Parsing failed for job {job_response.job.id}: {detail}" ) if job_response.job.status == CommonStatus.PROCESSED.value: - return self.get_inference(job_response.job.id) + result = self.get_inference( + job_response.job.id, response_type or InferenceResponse, True + ) + return result try_counter += 1 sleep(params.polling_options.delay_sec) raise MindeeError(f"Couldn't retrieve document after {try_counter + 1} tries.") + + def enqueue_and_get_inference( + self, + input_source: Union[LocalInputSource, UrlInputSource], + params: InferenceParameters, + ) -> InferenceResponse: + """[Deprecated] Use `enqueue_and_get_result` instead.""" + warnings.warn( + "enqueue_and_get_inference is deprecated; use enqueue_and_get_result", + DeprecationWarning, + stacklevel=2, + ) + response = self._enqueue_and_get(input_source, params) + assert isinstance(response, InferenceResponse), ( + f'Invalid response type "{type(response)}"' + ) + return response + + def enqueue_and_get_result( + self, + response_type: Type[TypeBaseInferenceResponse], + input_source: Union[LocalInputSource, UrlInputSource], + params: BaseParameters, + ) -> TypeBaseInferenceResponse: + """ + Enqueues to an asynchronous endpoint and automatically polls for a response. + + :param input_source: The document/source file to use. Can be local or remote. + + :param params: Parameters to set when sending a file. + + :param response_type: The product class to use for the response object. + + :return: A valid inference response. 
+ """ + response = self._enqueue_and_get(input_source, params, response_type) + assert isinstance(response, response_type), ( + f'Invalid response type "{type(response)}"' + ) + return response diff --git a/mindee/error/mindee_http_error_v2.py b/mindee/error/mindee_http_error_v2.py index 99ba40da..a6be90f3 100644 --- a/mindee/error/mindee_http_error_v2.py +++ b/mindee/error/mindee_http_error_v2.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import List, Optional from mindee.parsing.common.string_dict import StringDict from mindee.parsing.v2 import ErrorItem, ErrorResponse @@ -18,7 +18,7 @@ def __init__(self, response: ErrorResponse) -> None: self.title = response.title self.code = response.code self.detail = response.detail - self.errors: list[ErrorItem] = response.errors + self.errors: List[ErrorItem] = response.errors super().__init__( f"HTTP {self.status} - {self.title} :: {self.code} - {self.detail}" ) diff --git a/mindee/input/__init__.py b/mindee/input/__init__.py index 9ed79985..31973802 100644 --- a/mindee/input/__init__.py +++ b/mindee/input/__init__.py @@ -1,4 +1,7 @@ from mindee.input.local_response import LocalResponse +from mindee.input.base_parameters import BaseParameters +from mindee.input.inference_parameters import InferenceParameters +from mindee.v2.product.split.split_parameters import SplitParameters from mindee.input.page_options import PageOptions from mindee.input.polling_options import PollingOptions from mindee.input.sources.base_64_input import Base64Input @@ -11,15 +14,18 @@ from mindee.input.workflow_options import WorkflowOptions __all__ = [ + "Base64Input", + "BaseParameters", + "BytesInput", + "FileInput", "InputType", + "InferenceParameters", "LocalInputSource", - "UrlInputSource", + "LocalResponse", + "PageOptions", "PathInput", - "FileInput", - "Base64Input", - "BytesInput", - "WorkflowOptions", "PollingOptions", - "PageOptions", - "LocalResponse", + "UrlInputSource", + "SplitParameters", + "WorkflowOptions", ] diff --git a/mindee/input/base_parameters.py b/mindee/input/base_parameters.py new file mode 100644 index 00000000..af863052 --- /dev/null +++ b/mindee/input/base_parameters.py @@ -0,0 +1,44 @@ +from abc import ABC +from dataclasses import dataclass, field +from typing import Dict, Optional, List, Union + +from mindee.input.polling_options import PollingOptions + + +@dataclass +class BaseParameters(ABC): + """Base class for parameters accepted by all V2 endpoints.""" + + _slug: str = field(init=False) + """Slug of the endpoint.""" + + model_id: str + """ID of the model, required.""" + alias: Optional[str] = None + """Use an alias to link the file to your own DB. If empty, no alias will be used.""" + webhook_ids: Optional[List[str]] = None + """IDs of webhooks to propagate the API response to.""" + polling_options: Optional[PollingOptions] = None + """Options for polling. Set only if having timeout issues.""" + close_file: bool = True + """Whether to close the file after product.""" + + def get_config(self) -> Dict[str, Union[str, List[str]]]: + """ + Return the parameters as a config dictionary. + + :return: A dict of parameters. 
+ """ + data: Dict[str, Union[str, List[str]]] = { + "model_id": self.model_id, + } + if self.alias is not None: + data["alias"] = self.alias + if self.webhook_ids and len(self.webhook_ids) > 0: + data["webhook_ids"] = self.webhook_ids + return data + + @classmethod + def get_enqueue_slug(cls) -> str: + """Getter for the enqueue slug.""" + return cls._slug diff --git a/mindee/input/inference_parameters.py b/mindee/input/inference_parameters.py index 6d4e01fa..1554da31 100644 --- a/mindee/input/inference_parameters.py +++ b/mindee/input/inference_parameters.py @@ -1,8 +1,8 @@ import json -from dataclasses import dataclass, asdict -from typing import List, Optional, Union +from dataclasses import dataclass, asdict, field +from typing import Dict, List, Optional, Union -from mindee.input.polling_options import PollingOptions +from mindee.input.base_parameters import BaseParameters @dataclass @@ -44,7 +44,7 @@ class DataSchemaField(StringDataClass): guidelines: Optional[str] = None """Optional extraction guidelines.""" nested_fields: Optional[dict] = None - """Subfields when type is `nested_object`. Leave empty for other types""" + """Subfields when type is `nested_object`. Leave empty for other types.""" @dataclass @@ -78,11 +78,12 @@ def __post_init__(self) -> None: @dataclass -class InferenceParameters: +class InferenceParameters(BaseParameters): """Inference parameters to set when sending a file.""" - model_id: str - """ID of the model, required.""" + _slug: str = field(init=False, default="inferences") + """Slug of the endpoint.""" + rag: Optional[bool] = None """Enhance extraction accuracy with Retrieval-Augmented Generation.""" raw_text: Optional[bool] = None @@ -94,14 +95,6 @@ class InferenceParameters: Boost the precision and accuracy of all extractions. Calculate confidence scores for all fields, and fill their ``confidence`` attribute. """ - alias: Optional[str] = None - """Use an alias to link the file to your own DB. If empty, no alias will be used.""" - webhook_ids: Optional[List[str]] = None - """IDs of webhooks to propagate the API response to.""" - polling_options: Optional[PollingOptions] = None - """Options for polling. Set only if having timeout issues.""" - close_file: bool = True - """Whether to close the file after parsing.""" text_context: Optional[str] = None """ Additional text context used by the model during inference. @@ -118,3 +111,24 @@ def __post_init__(self): self.data_schema = DataSchema(**json.loads(self.data_schema)) elif isinstance(self.data_schema, dict): self.data_schema = DataSchema(**self.data_schema) + + def get_config(self) -> Dict[str, Union[str, List[str]]]: + """ + Return the parameters as a config dictionary. + + :return: A dict of parameters. 
+ """ + data = super().get_config() + if self.data_schema is not None: + data["data_schema"] = str(self.data_schema) + if self.rag is not None: + data["rag"] = data["rag"] = str(self.rag).lower() + if self.raw_text is not None: + data["raw_text"] = data["raw_text"] = str(self.raw_text).lower() + if self.polygon is not None: + data["polygon"] = data["polygon"] = str(self.polygon).lower() + if self.confidence is not None: + data["confidence"] = data["confidence"] = str(self.confidence).lower() + if self.text_context is not None: + data["text_context"] = self.text_context + return data diff --git a/mindee/mindee_http/mindee_api_v2.py b/mindee/mindee_http/mindee_api_v2.py index 9990330c..1cb7a8b2 100644 --- a/mindee/mindee_http/mindee_api_v2.py +++ b/mindee/mindee_http/mindee_api_v2.py @@ -4,8 +4,7 @@ import requests from mindee.error.mindee_error import MindeeApiV2Error -from mindee.input import LocalInputSource, UrlInputSource -from mindee.input.inference_parameters import InferenceParameters +from mindee.input import LocalInputSource, UrlInputSource, BaseParameters from mindee.logger import logger from mindee.mindee_http.base_settings import USER_AGENT from mindee.mindee_http.settings_mixin import SettingsMixin @@ -74,34 +73,19 @@ def set_from_env(self) -> None: def req_post_inference_enqueue( self, input_source: Union[LocalInputSource, UrlInputSource], - params: InferenceParameters, + params: BaseParameters, + slug: str, ) -> requests.Response: """ Make an asynchronous request to POST a document for prediction on the V2 API. :param input_source: Input object. :param params: Options for the enqueueing of the document. + :param slug: Slug to use for the enqueueing, defaults to 'inferences'. :return: requests response. """ - data: Dict[str, Union[str, list]] = {"model_id": params.model_id} - url = f"{self.url_root}/inferences/enqueue" - - if params.rag is not None: - data["rag"] = str(params.rag).lower() - if params.raw_text is not None: - data["raw_text"] = str(params.raw_text).lower() - if params.confidence is not None: - data["confidence"] = str(params.confidence).lower() - if params.polygon is not None: - data["polygon"] = str(params.polygon).lower() - if params.webhook_ids and len(params.webhook_ids) > 0: - data["webhook_ids"] = params.webhook_ids - if params.alias and len(params.alias): - data["alias"] = params.alias - if params.text_context and len(params.text_context): - data["text_context"] = params.text_context - if params.data_schema is not None: - data["data_schema"] = str(params.data_schema) + data = params.get_config() + url = f"{self.url_root}/{slug}/enqueue" if isinstance(input_source, LocalInputSource): files = {"file": input_source.read_contents(params.close_file)} @@ -137,14 +121,17 @@ def req_get_job(self, job_id: str) -> requests.Response: allow_redirects=False, ) - def req_get_inference(self, inference_id: str) -> requests.Response: + def req_get_inference(self, inference_id: str, slug: str) -> requests.Response: """ Sends a request matching a given queue_id. Returns either a Job or a Document. :param inference_id: Inference ID, returned by the job request. + :param slug: Slug of the inference, defaults to nothing. 
""" + + url = f"{self.url_root}/{slug}/{inference_id}" return requests.get( - f"{self.url_root}/inferences/{inference_id}", + url, headers=self.base_headers, timeout=self.request_timeout, allow_redirects=False, diff --git a/mindee/parsing/common/async_predict_response.py b/mindee/parsing/common/async_predict_response.py index e3101633..5d657532 100644 --- a/mindee/parsing/common/async_predict_response.py +++ b/mindee/parsing/common/async_predict_response.py @@ -23,7 +23,7 @@ def __init__( """ Container wrapper for a raw API response. - Inherits and instantiates a normal PredictResponse if the parsing of + Inherits and instantiates a normal PredictResponse if the product of the current queue is both requested and done. :param inference_type: Type of the inference. diff --git a/mindee/parsing/v2/inference.py b/mindee/parsing/v2/inference.py index 86c076c9..477cc41c 100644 --- a/mindee/parsing/v2/inference.py +++ b/mindee/parsing/v2/inference.py @@ -1,28 +1,19 @@ from mindee.parsing.common.string_dict import StringDict +from mindee.v2.parsing.inference import BaseInference from mindee.parsing.v2.inference_active_options import InferenceActiveOptions -from mindee.parsing.v2.inference_file import InferenceFile -from mindee.parsing.v2.inference_model import InferenceModel from mindee.parsing.v2.inference_result import InferenceResult -class Inference: +class Inference(BaseInference): """Inference object for a V2 API return.""" - id: str - """ID of the inference.""" - model: InferenceModel - """Model info for the inference.""" - file: InferenceFile - """File info for the inference.""" result: InferenceResult """Result of the inference.""" active_options: InferenceActiveOptions """Active options for the inference.""" def __init__(self, raw_response: StringDict): - self.id = raw_response["id"] - self.model = InferenceModel(raw_response["model"]) - self.file = InferenceFile(raw_response["file"]) + super().__init__(raw_response) self.result = InferenceResult(raw_response["result"]) self.active_options = InferenceActiveOptions(raw_response["active_options"]) diff --git a/mindee/parsing/v2/inference_response.py b/mindee/parsing/v2/inference_response.py index f1bb71c2..ff056d36 100644 --- a/mindee/parsing/v2/inference_response.py +++ b/mindee/parsing/v2/inference_response.py @@ -1,13 +1,17 @@ from mindee.parsing.common.string_dict import StringDict -from mindee.parsing.v2.common_response import CommonResponse from mindee.parsing.v2.inference import Inference +from mindee.v2.parsing.inference.base_response import ( + BaseResponse, +) -class InferenceResponse(CommonResponse): +class InferenceResponse(BaseResponse): """Represent an inference response from Mindee V2 API.""" inference: Inference """Inference result.""" + _slug: str = "inferences" + """Slug of the inference.""" def __init__(self, raw_response: StringDict) -> None: super().__init__(raw_response) @@ -15,3 +19,8 @@ def __init__(self, raw_response: StringDict) -> None: def __str__(self) -> str: return str(self.inference) + + @classmethod + def get_result_slug(cls) -> str: + """Getter for the inference slug.""" + return cls._slug diff --git a/mindee/v2/__init__.py b/mindee/v2/__init__.py new file mode 100644 index 00000000..136bbc42 --- /dev/null +++ b/mindee/v2/__init__.py @@ -0,0 +1,7 @@ +from mindee.v2.product.split.split_parameters import SplitParameters +from mindee.v2.product.split.split_response import SplitResponse + +__all__ = [ + "SplitResponse", + "SplitParameters", +] diff --git a/mindee/v2/parsing/__init__.py 
new file mode 100644
index 00000000..3ab40372
--- /dev/null
+++ b/mindee/v2/parsing/__init__.py
@@ -0,0 +1,7 @@
+from mindee.v2.parsing.inference.base_inference import BaseInference
+from mindee.v2.parsing.inference.base_response import BaseResponse
+
+__all__ = [
+    "BaseInference",
+    "BaseResponse",
+]
diff --git a/mindee/v2/parsing/inference/__init__.py b/mindee/v2/parsing/inference/__init__.py
new file mode 100644
index 00000000..e59b67ae
--- /dev/null
+++ b/mindee/v2/parsing/inference/__init__.py
@@ -0,0 +1,9 @@
+from mindee.v2.parsing.inference.base_inference import BaseInference
+from mindee.v2.parsing.inference.base_response import BaseResponse
+
+__all__ = [
+    "BaseInference",
+    "BaseResponse",
+]
diff --git a/mindee/v2/parsing/inference/base_inference.py b/mindee/v2/parsing/inference/base_inference.py
new file mode 100644
index 00000000..78462f0f
--- /dev/null
+++ b/mindee/v2/parsing/inference/base_inference.py
@@ -0,0 +1,25 @@
+from abc import ABC
+from typing import TypeVar
+
+from mindee.parsing.common.string_dict import StringDict
+from mindee.parsing.v2.inference_file import InferenceFile
+from mindee.parsing.v2.inference_model import InferenceModel
+
+
+class BaseInference(ABC):
+    """Base class for V2 inference objects."""
+
+    model: InferenceModel
+    """Model info for the inference."""
+    file: InferenceFile
+    """File info for the inference."""
+    id: str
+    """ID of the inference."""
+
+    def __init__(self, raw_response: StringDict):
+        self.id = raw_response["id"]
+        self.model = InferenceModel(raw_response["model"])
+        self.file = InferenceFile(raw_response["file"])
+
+
+TypeBaseInference = TypeVar("TypeBaseInference", bound=BaseInference)
diff --git a/mindee/v2/parsing/inference/base_response.py b/mindee/v2/parsing/inference/base_response.py
new file mode 100644
index 00000000..55b6deb6
--- /dev/null
+++ b/mindee/v2/parsing/inference/base_response.py
@@ -0,0 +1,22 @@
+from abc import ABC
+
+from mindee.parsing.v2.common_response import CommonResponse
+from mindee.v2.parsing.inference.base_inference import BaseInference
+
+
+class BaseResponse(ABC, CommonResponse):
+    """Base class for V2 inference responses."""
+
+    inference: BaseInference
+    """The inference result."""
+    _slug: str
+    """Slug of the inference."""
+
+    def __str__(self) -> str:
+        return str(self.inference)
+
+    @classmethod
+    def get_result_slug(cls) -> str:
+        """Getter for the inference slug."""
+        return cls._slug
diff --git a/mindee/v2/product/__init__.py b/mindee/v2/product/__init__.py
new file mode 100644
index 00000000..136bbc42
--- /dev/null
+++ b/mindee/v2/product/__init__.py
@@ -0,0 +1,7 @@
+from mindee.v2.product.split.split_parameters import SplitParameters
+from mindee.v2.product.split.split_response import SplitResponse
+
+__all__ = [
+    "SplitParameters",
+    "SplitResponse",
+]
diff --git a/mindee/v2/product/split/__init__.py b/mindee/v2/product/split/__init__.py
new file mode 100644
index 00000000..9284c63e
--- /dev/null
+++ b/mindee/v2/product/split/__init__.py
@@ -0,0 +1,13 @@
+from mindee.v2.product.split.split_inference import SplitInference
+from mindee.v2.product.split.split_parameters import SplitParameters
+from mindee.v2.product.split.split_range import SplitRange
+from mindee.v2.product.split.split_response import SplitResponse
+from mindee.v2.product.split.split_result import SplitResult
+
+__all__ = [
+    "SplitInference",
+    "SplitParameters",
+    "SplitRange",
+    "SplitResponse",
+    "SplitResult",
+]
diff --git a/mindee/v2/product/split/split_inference.py b/mindee/v2/product/split/split_inference.py
new file mode 100644
index 00000000..37aa6edb
--- /dev/null
+++ b/mindee/v2/product/split/split_inference.py
@@ -0,0 +1,19 @@
+from mindee.parsing.common.string_dict import StringDict
+from mindee.v2.parsing.inference.base_inference import BaseInference
+from mindee.v2.product.split.split_result import SplitResult
+
+
+class SplitInference(BaseInference):
+    """Split inference result."""
+
+    result: SplitResult
+    """Result of a split inference."""
+    _slug: str = "split"
+    """Slug of the endpoint."""
+
+    def __init__(self, raw_response: StringDict) -> None:
+        super().__init__(raw_response)
+        self.result = SplitResult(raw_response["result"])
+
+    def __str__(self) -> str:
+        return f"Inference\n#########\n{self.model}\n{self.file}\n{self.result}\n"
diff --git a/mindee/v2/product/split/split_parameters.py b/mindee/v2/product/split/split_parameters.py
new file mode 100644
index 00000000..191070f6
--- /dev/null
+++ b/mindee/v2/product/split/split_parameters.py
@@ -0,0 +1,9 @@
+from mindee.input.base_parameters import BaseParameters
+
+
+class SplitParameters(BaseParameters):
+    """Parameters accepted by the split utility V2 endpoint."""
+
+    _slug: str = "utilities/split"
diff --git a/mindee/v2/product/split/split_range.py b/mindee/v2/product/split/split_range.py
new file mode 100644
index 00000000..21a85405
--- /dev/null
+++ b/mindee/v2/product/split/split_range.py
@@ -0,0 +1,20 @@
+from typing import List
+
+from mindee.parsing.common.string_dict import StringDict
+
+
+class SplitRange:
+    """A single page range within a split result."""
+
+    page_range: List[int]
+    """Page range of the split."""
+    document_type: str
+    """Document type of the split."""
+
+    def __init__(self, server_response: StringDict):
+        self.page_range = server_response["page_range"]
+        self.document_type = server_response["document_type"]
+
+    def __str__(self) -> str:
+        page_range = ",".join([str(page_index) for page_index in self.page_range])
+        return f"* :Page Range: {page_range}\n  :Document Type: {self.document_type}"
diff --git a/mindee/v2/product/split/split_response.py b/mindee/v2/product/split/split_response.py
new file mode 100644
index 00000000..dfb3c6d5
--- /dev/null
+++ b/mindee/v2/product/split/split_response.py
@@ -0,0 +1,17 @@
+from mindee.parsing.common.string_dict import StringDict
+from mindee.v2.parsing.inference import BaseResponse
+from mindee.v2.product.split.split_inference import SplitInference
+
+
+class SplitResponse(BaseResponse):
+    """Represent a split inference response from Mindee V2 API."""
+
+    inference: SplitInference
+    """Inference object for split inference."""
+
+    _slug: str = "utilities/split"
+    """Slug of the inference."""
+
+    def __init__(self, raw_response: StringDict) -> None:
+        super().__init__(raw_response)
+        self.inference = SplitInference(raw_response["inference"])
diff --git a/mindee/v2/product/split/split_result.py b/mindee/v2/product/split/split_result.py
new file mode 100644
index 00000000..99c57845
--- /dev/null
+++ b/mindee/v2/product/split/split_result.py
@@ -0,0 +1,20 @@
+from typing import List
+
+from mindee.parsing.common.string_dict import StringDict
+from mindee.v2.product.split.split_range import SplitRange
+
+
+class SplitResult:
+    """Split result info."""
+
+    splits: List[SplitRange]
+    """List of split ranges detected in the document."""
+
+    def __init__(self, raw_response: StringDict) -> None:
+        self.splits = [SplitRange(split) for split in raw_response["split"]]
+
+    def __str__(self) -> str:
+        splits = "\n"
+        if len(self.splits) > 0:
+            splits += "\n\n".join([str(split) for split in self.splits])
+        out_str = f"Splits\n======{splits}"
+        return out_str
diff --git a/tests/data b/tests/data
index 0c51e1d3..c30c33b5 160000
--- a/tests/data
+++ b/tests/data
@@ -1 +1 @@
-Subproject commit 0c51e1d3e2258404c44280f25f4951ba6fe27324
+Subproject commit c30c33b5217613223398a4b814e9cd96e8255789
diff --git a/tests/utils.py b/tests/utils.py
index 252a699c..058e3595 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -17,6 +17,7 @@
 V2_DATA_DIR = ROOT_DATA_DIR / "v2"
 V2_PRODUCT_DATA_DIR = V2_DATA_DIR / "products"
+V2_UTILITIES_DATA_DIR = V2_DATA_DIR / "utilities"


 def clear_envvars(monkeypatch) -> None:
diff --git a/tests/v2/input/test_local_response.py b/tests/v2/input/test_local_response.py
index 5ce07fe1..5db8be78 100644
--- a/tests/v2/input/test_local_response.py
+++ b/tests/v2/input/test_local_response.py
@@ -14,7 +14,7 @@ def file_path() -> Path:

 def _assert_local_response(local_response):
     fake_hmac_signing = "ogNjY44MhvKPGTtVsI8zG82JqWQa68woYQH"
-    signature = "1df388c992d87897fe61dfc56c444c58fc3c7369c31e2b5fd20d867695e93e85"
+    signature = "f390d9f7f57ac04f47b6309d8a40236b0182610804fc20e91b1f6028aaca07a7"

     assert local_response._file is not None
     assert not local_response.is_valid_hmac_signature(
diff --git a/tests/v2/parsing/test_split_integration.py b/tests/v2/parsing/test_split_integration.py
new file mode 100644
index 00000000..efca8b07
--- /dev/null
+++ b/tests/v2/parsing/test_split_integration.py
@@ -0,0 +1,39 @@
+import os
+
+import pytest
+
+from mindee import ClientV2, PathInput
+from mindee.input import SplitParameters
+from mindee.v2 import SplitResponse
+from tests.utils import V1_PRODUCT_DATA_DIR
+
+
+@pytest.fixture(scope="session")
+def split_model_id() -> str:
+    """Identifier of the Split model, supplied through an env var."""
+    return os.getenv("MINDEE_V2_SE_TESTS_SPLIT_MODEL_ID")
+
+
+@pytest.fixture(scope="session")
+def v2_client() -> ClientV2:
+    """Real V2 client configured with the user-supplied API key."""
+    api_key = os.getenv("MINDEE_V2_API_KEY")
+    return ClientV2(api_key)
+
+
+@pytest.mark.integration
+@pytest.mark.v2
+def test_split_default_sample(v2_client: ClientV2, split_model_id: str):
+    input_source = PathInput(
+        V1_PRODUCT_DATA_DIR / "invoice_splitter" / "default_sample.pdf"
+    )
+    response = v2_client.enqueue_and_get_result(
+        SplitResponse, input_source, SplitParameters(split_model_id)
+    )  # Note: do not use blank_1.pdf for this.
+    assert response.inference is not None
+    assert response.inference.file.name == "default_sample.pdf"
+    assert response.inference.result.splits
+    assert len(response.inference.result.splits) == 2
diff --git a/tests/v2/parsing/test_split_response.py b/tests/v2/parsing/test_split_response.py
new file mode 100644
index 00000000..0ce0d707
--- /dev/null
+++ b/tests/v2/parsing/test_split_response.py
@@ -0,0 +1,45 @@
+import pytest
+
+from mindee import LocalResponse
+from mindee.v2.product.split import SplitInference
+from mindee.v2.product.split.split_range import SplitRange
+from mindee.v2.product.split.split_response import SplitResponse
+from mindee.v2.product.split.split_result import SplitResult
+from tests.utils import V2_UTILITIES_DATA_DIR
+
+
+@pytest.mark.v2
+def test_split_single():
+    input_inference = LocalResponse(V2_UTILITIES_DATA_DIR / "split_single.json")
+    split_response = input_inference.deserialize_response(SplitResponse)
+    assert isinstance(split_response.inference, SplitInference)
+    assert split_response.inference.result.splits
+    assert len(split_response.inference.result.splits[0].page_range) == 2
+    assert split_response.inference.result.splits[0].page_range[0] == 0
+    assert split_response.inference.result.splits[0].page_range[1] == 0
+    assert split_response.inference.result.splits[0].document_type == "receipt"
+
+
+@pytest.mark.v2
+def test_split_multiple():
+    input_inference = LocalResponse(V2_UTILITIES_DATA_DIR / "split_multiple.json")
+    split_response = input_inference.deserialize_response(SplitResponse)
+    assert isinstance(split_response.inference, SplitInference)
+    assert isinstance(split_response.inference.result, SplitResult)
+    assert isinstance(split_response.inference.result.splits[0], SplitRange)
+    assert len(split_response.inference.result.splits) == 3
+
+    assert len(split_response.inference.result.splits[0].page_range) == 2
+    assert split_response.inference.result.splits[0].page_range[0] == 0
+    assert split_response.inference.result.splits[0].page_range[1] == 0
+    assert split_response.inference.result.splits[0].document_type == "invoice"
+
+    assert len(split_response.inference.result.splits[1].page_range) == 2
+    assert split_response.inference.result.splits[1].page_range[0] == 1
+    assert split_response.inference.result.splits[1].page_range[1] == 3
+    assert split_response.inference.result.splits[1].document_type == "invoice"
+
+    assert len(split_response.inference.result.splits[2].page_range) == 2
+    assert split_response.inference.result.splits[2].page_range[0] == 4
+    assert split_response.inference.result.splits[2].page_range[1] == 4
+    assert split_response.inference.result.splits[2].document_type == "invoice"
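For reference, a minimal usage sketch of the new split utility endpoint introduced by this diff. It mirrors the integration test above; the API key and model ID come from the same environment variables, while "multipage_invoices.pdf" is a placeholder path, not part of the change itself.

import os

from mindee import ClientV2, PathInput
from mindee.input import SplitParameters
from mindee.v2 import SplitResponse

# Placeholder credentials and file path -- replace with real values.
client = ClientV2(os.getenv("MINDEE_V2_API_KEY"))
input_source = PathInput("multipage_invoices.pdf")
params = SplitParameters(os.getenv("MINDEE_V2_SE_TESTS_SPLIT_MODEL_ID"))

# Enqueue on the utilities/split endpoint and poll until the job is processed.
response = client.enqueue_and_get_result(SplitResponse, input_source, params)

# Each SplitRange exposes the detected page range and document type.
for split in response.inference.result.splits:
    print(split.page_range, split.document_type)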