diff --git a/src/crawlee/_types.py b/src/crawlee/_types.py index da11adae5d..a98664d02d 100644 --- a/src/crawlee/_types.py +++ b/src/crawlee/_types.py @@ -2,6 +2,7 @@ import dataclasses from collections.abc import Callable, Iterator, Mapping +from copy import deepcopy from dataclasses import dataclass from typing import TYPE_CHECKING, Annotated, Any, Literal, Protocol, TypedDict, TypeVar, cast, overload @@ -260,12 +261,24 @@ async def get_value(self, key: str, default_value: T | None = None) -> T | None: class RequestHandlerRunResult: """Record of calls to storage-related context helpers.""" - def __init__(self, *, key_value_store_getter: GetKeyValueStoreFunction) -> None: + def __init__( + self, + *, + key_value_store_getter: GetKeyValueStoreFunction, + request: Request, + ) -> None: self._key_value_store_getter = key_value_store_getter self.add_requests_calls = list[AddRequestsKwargs]() self.push_data_calls = list[PushDataFunctionCall]() self.key_value_store_changes = dict[tuple[str | None, str | None, str | None], KeyValueStoreChangeRecords]() + # Isolated copy of the request for handler execution. + self._request = deepcopy(request) + + @property + def request(self) -> Request: + return self._request + async def add_requests( self, requests: Sequence[str | Request], @@ -315,6 +328,14 @@ async def get_key_value_store( return self.key_value_store_changes[id, name, alias] + def apply_request_changes(self, target: Request) -> None: + """Apply tracked changes from the handler's copy to the original request.""" + if self.request.user_data != target.user_data: + target.user_data = self.request.user_data + + if self.request.headers != target.headers: + target.headers = self.request.headers + @docs_group('Functions') class AddRequestsFunction(Protocol): diff --git a/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py b/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py index c51180e1fc..b3b99e6f59 100644 --- a/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py +++ b/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py @@ -290,11 +290,14 @@ async def get_input_state( use_state_function = context.use_state # New result is created and injected to newly created context. This is done to ensure isolation of sub crawlers.
- result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store) + result = RequestHandlerRunResult( + key_value_store_getter=self.get_key_value_store, + request=context.request, + ) context_linked_to_result = BasicCrawlingContext( - request=deepcopy(context.request), - session=deepcopy(context.session), - proxy_info=deepcopy(context.proxy_info), + request=result.request, + session=context.session, + proxy_info=context.proxy_info, send_request=context.send_request, add_requests=result.add_requests, push_data=result.push_data, @@ -314,7 +317,7 @@ async def get_input_state( ), logger=self._logger, ) - return SubCrawlerRun(result=result, run_context=context_linked_to_result) + return SubCrawlerRun(result=result) except Exception as e: return SubCrawlerRun(exception=e) @@ -370,8 +373,7 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None: self.track_http_only_request_handler_runs() static_run = await self._crawl_one(rendering_type='static', context=context) - if static_run.result and static_run.run_context and self.result_checker(static_run.result): - self._update_context_from_copy(context, static_run.run_context) + if static_run.result and self.result_checker(static_run.result): self._context_result_map[context] = static_run.result return if static_run.exception: @@ -402,7 +404,7 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None: if pw_run.exception is not None: raise pw_run.exception - if pw_run.result and pw_run.run_context: + if pw_run.result: if should_detect_rendering_type: detection_result: RenderingType static_run = await self._crawl_one('static', context=context, state=old_state_copy) @@ -414,7 +416,6 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None: context.log.debug(f'Detected rendering type {detection_result} for {context.request.url}') self.rendering_type_predictor.store_result(context.request, detection_result) - self._update_context_from_copy(context, pw_run.run_context) self._context_result_map[context] = pw_run.result def pre_navigation_hook( @@ -451,32 +452,8 @@ def track_browser_request_handler_runs(self) -> None: def track_rendering_type_mispredictions(self) -> None: self.statistics.state.rendering_type_mispredictions += 1 - def _update_context_from_copy(self, context: BasicCrawlingContext, context_copy: BasicCrawlingContext) -> None: - """Update mutable fields of `context` from `context_copy`. - - Uses object.__setattr__ to bypass frozen dataclass restrictions, - allowing state synchronization after isolated crawler execution. 
- """ - updating_attributes = { - 'request': ('headers', 'user_data'), - 'session': ('_user_data', '_usage_count', '_error_score', '_cookies'), - } - - for attr, sub_attrs in updating_attributes.items(): - original_sub_obj = getattr(context, attr) - copy_sub_obj = getattr(context_copy, attr) - - # Check that both sub objects are not None - if original_sub_obj is None or copy_sub_obj is None: - continue - - for sub_attr in sub_attrs: - new_value = getattr(copy_sub_obj, sub_attr) - object.__setattr__(original_sub_obj, sub_attr, new_value) - @dataclass(frozen=True) class SubCrawlerRun: result: RequestHandlerRunResult | None = None exception: Exception | None = None - run_context: BasicCrawlingContext | None = None diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index cc6664cbd6..5cc2993c1c 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -69,6 +69,7 @@ from crawlee.storages import Dataset, KeyValueStore, RequestQueue from ._context_pipeline import ContextPipeline +from ._context_utils import swaped_context from ._logging_utils import ( get_one_line_error_summary_if_possible, reduce_asyncio_timeout_error_to_relevant_traceback_parts, @@ -1321,6 +1322,8 @@ async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> await self._commit_key_value_store_changes(result, get_kvs=self.get_key_value_store) + result.apply_request_changes(target=context.request) + @staticmethod async def _commit_key_value_store_changes( result: RequestHandlerRunResult, get_kvs: GetKeyValueStoreFromRequestHandlerFunction @@ -1386,10 +1389,10 @@ async def __run_task_function(self) -> None: else: session = await self._get_session() proxy_info = await self._get_proxy_info(request, session) - result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store) + result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store, request=request) context = BasicCrawlingContext( - request=request, + request=result.request, session=session, proxy_info=proxy_info, send_request=self._prepare_send_request_function(session, proxy_info), @@ -1404,10 +1407,12 @@ async def __run_task_function(self) -> None: self._statistics.record_request_processing_start(request.unique_key) try: - self._check_request_collision(context.request, context.session) + request.state = RequestState.REQUEST_HANDLER try: - await self._run_request_handler(context=context) + with swaped_context(context, request): + self._check_request_collision(request, session) + await self._run_request_handler(context=context) except asyncio.TimeoutError as e: raise RequestHandlerError(e, context) from e @@ -1417,13 +1422,13 @@ async def __run_task_function(self) -> None: await self._mark_request_as_handled(request) - if context.session and context.session.is_usable: - context.session.mark_good() + if session and session.is_usable: + session.mark_good() self._statistics.record_request_processing_finish(request.unique_key) except RequestCollisionError as request_error: - context.request.no_retry = True + request.no_retry = True await self._handle_request_error(context, request_error) except RequestHandlerError as primary_error: @@ -1438,7 +1443,7 @@ async def __run_task_function(self) -> None: await self._handle_request_error(primary_error.crawling_context, primary_error.wrapped_exception) except SessionError as session_error: - if not context.session: + if not session: raise RuntimeError('SessionError raised 
in a crawling context without a session') from session_error if self._error_handler: @@ -1448,10 +1453,11 @@ exc_only = ''.join(traceback.format_exception_only(session_error)).strip() self._logger.warning('Encountered "%s", rotating session and retrying...', exc_only) - context.session.retire() + if session: + session.retire() # Increment session rotation count. - context.request.session_rotation_count = (context.request.session_rotation_count or 0) + 1 + request.session_rotation_count = (request.session_rotation_count or 0) + 1 await request_manager.reclaim_request(request) await self._statistics.error_tracker_retry.add(error=session_error, context=context) diff --git a/src/crawlee/crawlers/_basic/_context_utils.py b/src/crawlee/crawlers/_basic/_context_utils.py new file mode 100644 index 0000000000..a28e5582d6 --- /dev/null +++ b/src/crawlee/crawlers/_basic/_context_utils.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterator + + from crawlee._request import Request + + from ._basic_crawling_context import BasicCrawlingContext + + +@contextmanager +def swapped_context( + context: BasicCrawlingContext, + request: Request, +) -> Iterator[None]: + """Restore the original request on the context after handler execution.""" + try: + yield + finally: + # Restore original context state to avoid side effects between different handlers. + object.__setattr__(context, 'request', request) diff --git a/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py b/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py index b2e3a41853..c2d0f153b1 100644 --- a/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py +++ b/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py @@ -802,6 +802,7 @@ async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None: assert session is not None assert check_request is not None + assert session.user_data.get('session_state') is True # Check that request user data was updated in the handler and only once. 
assert check_request.user_data.get('request_state') == ['initial', 'handler'] diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py index 6bc1b4f479..d4ac09ff85 100644 --- a/tests/unit/crawlers/_basic/test_basic_crawler.py +++ b/tests/unit/crawlers/_basic/test_basic_crawler.py @@ -1825,6 +1825,29 @@ async def handler(_: BasicCrawlingContext) -> None: await crawler_task +async def test_protect_request_in_run_handlers() -> None: + """Test that the request in the crawling context is protected during request handler runs.""" + request_queue = await RequestQueue.open(name='state-test') + + request = Request.from_url('https://test.url/', user_data={'request_state': ['initial']}) + + crawler = BasicCrawler(request_manager=request_queue, max_request_retries=0) + + @crawler.router.default_handler + async def handler(context: BasicCrawlingContext) -> None: + if isinstance(context.request.user_data['request_state'], list): + context.request.user_data['request_state'].append('modified') + raise ValueError('Simulated error after modifying request') + + await crawler.run([request]) + + check_request = await request_queue.get_request(request.unique_key) + assert check_request is not None + assert check_request.user_data['request_state'] == ['initial'] + + await request_queue.drop() + + async def test_new_request_error_handler() -> None: """Test that error in new_request_handler is handled properly.""" queue = await RequestQueue.open()