23 changes: 22 additions & 1 deletion src/crawlee/_types.py
@@ -2,6 +2,7 @@

import dataclasses
from collections.abc import Callable, Iterator, Mapping
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated, Any, Literal, Protocol, TypedDict, TypeVar, cast, overload

@@ -260,12 +261,24 @@ async def get_value(self, key: str, default_value: T | None = None) -> T | None:
class RequestHandlerRunResult:
"""Record of calls to storage-related context helpers."""

def __init__(self, *, key_value_store_getter: GetKeyValueStoreFunction) -> None:
def __init__(
self,
*,
key_value_store_getter: GetKeyValueStoreFunction,
request: Request,
) -> None:
self._key_value_store_getter = key_value_store_getter
self.add_requests_calls = list[AddRequestsKwargs]()
self.push_data_calls = list[PushDataFunctionCall]()
self.key_value_store_changes = dict[tuple[str | None, str | None, str | None], KeyValueStoreChangeRecords]()

# Isolated copies for handler execution
self._request = deepcopy(request)

@property
def request(self) -> Request:
return self._request

async def add_requests(
self,
requests: Sequence[str | Request],
@@ -315,6 +328,14 @@ async def get_key_value_store(

return self.key_value_store_changes[id, name, alias]

def apply_request_changes(self, target: Request) -> None:
"""Apply tracked changes from handler copy to original request."""
if self.request.user_data != target.user_data:
target.user_data = self.request.user_data

if self.request.headers != target.headers:
target.headers = self.request.headers


@docs_group('Functions')
class AddRequestsFunction(Protocol):
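To make the new contract concrete, here is a minimal, runnable sketch of the deep-copy-then-sync-back pattern that `RequestHandlerRunResult` adopts above. The `FakeRequest`/`FakeRunResult` names are illustrative stand-ins, not the actual crawlee classes:

```python
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class FakeRequest:
    """Illustrative stand-in for crawlee's Request (not the real class)."""

    url: str
    user_data: dict = field(default_factory=dict)
    headers: dict = field(default_factory=dict)


class FakeRunResult:
    """Illustrative stand-in for RequestHandlerRunResult's isolation logic."""

    def __init__(self, *, request: FakeRequest) -> None:
        # The handler only ever sees this isolated deep copy.
        self._request = deepcopy(request)

    @property
    def request(self) -> FakeRequest:
        return self._request

    def apply_request_changes(self, target: FakeRequest) -> None:
        # Sync back only the fields a handler is allowed to mutate.
        if self.request.user_data != target.user_data:
            target.user_data = self.request.user_data
        if self.request.headers != target.headers:
            target.headers = self.request.headers


original = FakeRequest('https://example.com', user_data={'state': ['initial']})
result = FakeRunResult(request=original)

# The handler mutates the copy; the original stays untouched until commit.
result.request.user_data['state'].append('handler')
assert original.user_data['state'] == ['initial']

# On a successful run the tracked changes are applied back to the original.
result.apply_request_changes(target=original)
assert original.user_data['state'] == ['initial', 'handler']
```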
@@ -290,11 +290,14 @@ async def get_input_state(
use_state_function = context.use_state

# A new result is created and injected into the newly created context. This is done to ensure isolation of the sub-crawlers.
result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store)
result = RequestHandlerRunResult(
key_value_store_getter=self.get_key_value_store,
request=context.request,
)
context_linked_to_result = BasicCrawlingContext(
request=deepcopy(context.request),
session=deepcopy(context.session),
proxy_info=deepcopy(context.proxy_info),
request=result.request,
session=context.session,
proxy_info=context.proxy_info,
send_request=context.send_request,
add_requests=result.add_requests,
push_data=result.push_data,
@@ -314,7 +317,7 @@ async def get_input_state(
),
logger=self._logger,
)
return SubCrawlerRun(result=result, run_context=context_linked_to_result)
return SubCrawlerRun(result=result)
except Exception as e:
return SubCrawlerRun(exception=e)

@@ -370,8 +373,7 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
self.track_http_only_request_handler_runs()

static_run = await self._crawl_one(rendering_type='static', context=context)
if static_run.result and static_run.run_context and self.result_checker(static_run.result):
self._update_context_from_copy(context, static_run.run_context)
if static_run.result and self.result_checker(static_run.result):
self._context_result_map[context] = static_run.result
return
if static_run.exception:
@@ -402,7 +404,7 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
if pw_run.exception is not None:
raise pw_run.exception

if pw_run.result and pw_run.run_context:
if pw_run.result:
if should_detect_rendering_type:
detection_result: RenderingType
static_run = await self._crawl_one('static', context=context, state=old_state_copy)
@@ -414,7 +416,6 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
context.log.debug(f'Detected rendering type {detection_result} for {context.request.url}')
self.rendering_type_predictor.store_result(context.request, detection_result)

self._update_context_from_copy(context, pw_run.run_context)
self._context_result_map[context] = pw_run.result

def pre_navigation_hook(
@@ -451,32 +452,8 @@ def track_browser_request_handler_runs(self) -> None:
def track_rendering_type_mispredictions(self) -> None:
self.statistics.state.rendering_type_mispredictions += 1

def _update_context_from_copy(self, context: BasicCrawlingContext, context_copy: BasicCrawlingContext) -> None:
"""Update mutable fields of `context` from `context_copy`.

Uses object.__setattr__ to bypass frozen dataclass restrictions,
allowing state synchronization after isolated crawler execution.
"""
updating_attributes = {
'request': ('headers', 'user_data'),
'session': ('_user_data', '_usage_count', '_error_score', '_cookies'),
}

for attr, sub_attrs in updating_attributes.items():
original_sub_obj = getattr(context, attr)
copy_sub_obj = getattr(context_copy, attr)

# Check that both sub objects are not None
if original_sub_obj is None or copy_sub_obj is None:
continue

for sub_attr in sub_attrs:
new_value = getattr(copy_sub_obj, sub_attr)
object.__setattr__(original_sub_obj, sub_attr, new_value)


@dataclass(frozen=True)
class SubCrawlerRun:
result: RequestHandlerRunResult | None = None
exception: Exception | None = None
run_context: BasicCrawlingContext | None = None
26 changes: 16 additions & 10 deletions src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -69,6 +69,7 @@
from crawlee.storages import Dataset, KeyValueStore, RequestQueue

from ._context_pipeline import ContextPipeline
from ._context_utils import swaped_context
from ._logging_utils import (
get_one_line_error_summary_if_possible,
reduce_asyncio_timeout_error_to_relevant_traceback_parts,
@@ -1321,6 +1322,8 @@ async def _commit_request_handler_result(self, context: BasicCrawlingContext) ->

await self._commit_key_value_store_changes(result, get_kvs=self.get_key_value_store)

result.apply_request_changes(target=context.request)

@staticmethod
async def _commit_key_value_store_changes(
result: RequestHandlerRunResult, get_kvs: GetKeyValueStoreFromRequestHandlerFunction
@@ -1386,10 +1389,10 @@ async def __run_task_function(self) -> None:
else:
session = await self._get_session()
proxy_info = await self._get_proxy_info(request, session)
result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store)
result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store, request=request)

context = BasicCrawlingContext(
request=request,
request=result.request,
session=session,
proxy_info=proxy_info,
send_request=self._prepare_send_request_function(session, proxy_info),
@@ -1404,10 +1407,12 @@ async def __run_task_function(self) -> None:
self._statistics.record_request_processing_start(request.unique_key)

try:
self._check_request_collision(context.request, context.session)
request.state = RequestState.REQUEST_HANDLER

try:
await self._run_request_handler(context=context)
with swaped_context(context, request):
Review comment (Collaborator):

The name is somewhat confusing. It sets the original Request object on the read-only property context.request. It is kind of breaking our own constraint.

Reply (Collaborator, Author):

Yes, this was done so that the protected request copy is only used within _run_request_handler.

Outside of _run_request_handler, we always work with the original Request, since we work with fields such as state, retry_count, no_retry, and others. Their state is not synchronized, since we expect that the user will not change them in user handlers.

But yes, it is not intuitive and makes the code more difficult to read.

self._check_request_collision(request, session)
await self._run_request_handler(context=context)
except asyncio.TimeoutError as e:
raise RequestHandlerError(e, context) from e

@@ -1417,13 +1422,13 @@ async def __run_task_function(self) -> None:

await self._mark_request_as_handled(request)

if context.session and context.session.is_usable:
context.session.mark_good()
if session and session.is_usable:
session.mark_good()

self._statistics.record_request_processing_finish(request.unique_key)

except RequestCollisionError as request_error:
context.request.no_retry = True
request.no_retry = True
await self._handle_request_error(context, request_error)

except RequestHandlerError as primary_error:
@@ -1438,7 +1443,7 @@ async def __run_task_function(self) -> None:
await self._handle_request_error(primary_error.crawling_context, primary_error.wrapped_exception)

except SessionError as session_error:
if not context.session:
if not session:
raise RuntimeError('SessionError raised in a crawling context without a session') from session_error

if self._error_handler:
@@ -1448,10 +1453,11 @@ async def __run_task_function(self) -> None:
exc_only = ''.join(traceback.format_exception_only(session_error)).strip()
self._logger.warning('Encountered "%s", rotating session and retrying...', exc_only)

context.session.retire()
if session:
session.retire()

# Increment session rotation count.
context.request.session_rotation_count = (context.request.session_rotation_count or 0) + 1
request.session_rotation_count = (request.session_rotation_count or 0) + 1

await request_manager.reclaim_request(request)
await self._statistics.error_tracker_retry.add(error=session_error, context=context)
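The review thread above turns on one mechanic: `BasicCrawlingContext` is a frozen dataclass, so putting the original request back on `context.request` requires bypassing the frozen check with `object.__setattr__`. Here is a standalone sketch of that swap under assumed names (this is not the crawlee API):

```python
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeContext:
    """Illustrative frozen context; normal assignment to fields raises."""

    request: dict


@contextmanager
def swapped_request(context: FakeContext, original: dict) -> Iterator[None]:
    try:
        yield  # The handler runs here, seeing the isolated copy.
    finally:
        # Frozen dataclasses forbid `context.request = ...`, so bypass the
        # dataclass machinery to restore the original once the handler is done.
        object.__setattr__(context, 'request', original)


original = {'url': 'https://example.com', 'user_data': {}}
context = FakeContext(request=deepcopy(original))

with swapped_request(context, original):
    context.request['user_data']['key'] = 'handler-only'  # mutates the copy

assert context.request is original            # original swapped back in
assert 'key' not in original['user_data']     # the mutation did not leak
```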
24 changes: 24 additions & 0 deletions src/crawlee/crawlers/_basic/_context_utils.py
@@ -0,0 +1,24 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Iterator

from crawlee._request import Request

from ._basic_crawling_context import BasicCrawlingContext


@contextmanager
def swaped_context(
context: BasicCrawlingContext,
request: Request,
) -> Iterator[None]:
"""Replace context's isolated copies with originals after handler execution."""
try:
yield
finally:
# Restore original context state to avoid side effects between different handlers.
object.__setattr__(context, 'request', request)
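
Because the restore sits in a `finally` block, the original request is put back even when the handler raises, which is what keeps failed runs free of side effects (the new `test_protect_request_in_run_handlers` below exercises exactly this path). A tiny sketch of that guarantee, with illustrative names only:

```python
from collections.abc import Iterator
from contextlib import contextmanager


@contextmanager
def restoring(holder: dict, original: object) -> Iterator[None]:
    try:
        yield
    finally:
        # Runs on success *and* on exception, so the swap never leaks out.
        holder['request'] = original


holder = {'request': 'isolated-copy'}
try:
    with restoring(holder, 'original'):
        raise ValueError('simulated handler failure')
except ValueError:
    pass

assert holder['request'] == 'original'
```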
@@ -802,6 +802,7 @@ async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None:

assert session is not None
assert check_request is not None

assert session.user_data.get('session_state') is True
# Check that the request user data was updated in the handler, and only once.
assert check_request.user_data.get('request_state') == ['initial', 'handler']
23 changes: 23 additions & 0 deletions tests/unit/crawlers/_basic/test_basic_crawler.py
@@ -1825,6 +1825,29 @@ async def handler(_: BasicCrawlingContext) -> None:
await crawler_task


async def test_protect_request_in_run_handlers() -> None:
"""Test that request in crawling context are protected in run handlers."""
request_queue = await RequestQueue.open(name='state-test')

request = Request.from_url('https://test.url/', user_data={'request_state': ['initial']})

crawler = BasicCrawler(request_manager=request_queue, max_request_retries=0)

@crawler.router.default_handler
async def handler(context: BasicCrawlingContext) -> None:
if isinstance(context.request.user_data['request_state'], list):
context.request.user_data['request_state'].append('modified')
raise ValueError('Simulated error after modifying request')

await crawler.run([request])

check_request = await request_queue.get_request(request.unique_key)
assert check_request is not None
assert check_request.user_data['request_state'] == ['initial']

await request_queue.drop()


async def test_new_request_error_handler() -> None:
"""Test that error in new_request_handler is handled properly."""
queue = await RequestQueue.open()