Skip to content

Commit b6f879c

Browse files
Feat/intercept (#157)
* feat: intercept * fix fetch domain + changelogs * using example file for intercept test --------- Co-authored-by: Stephan Lensky <8302875+stephanlensky@users.noreply.github.com>
1 parent 2558388 commit b6f879c

File tree

6 files changed

+232
-1
lines changed

6 files changed

+232
-1
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
### Added
1313

14+
- Added `Tab.intercept` @nathanfallet
15+
1416
### Changed
1517

1618
### Removed

tests/core/test_tab.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import zendriver as zd
66
from tests.sample_data import sample_file
7+
from zendriver.cdp.fetch import RequestStage
8+
from zendriver.cdp.network import ResourceType
79

810

911
async def test_set_user_agent_sets_navigator_values(browser: zd.Browser):
@@ -219,3 +221,20 @@ async def test_expect_download(browser: zd.Browser):
219221
download = await asyncio.wait_for(download_ex.value, timeout=3)
220222
assert type(download) is zd.cdp.browser.DownloadWillBegin
221223
assert download.url is not None
224+
225+
226+
async def test_intercept(browser: zd.Browser):
227+
tab = browser.main_tab
228+
229+
async with tab.intercept(
230+
"*/user-data.json",
231+
RequestStage.RESPONSE,
232+
ResourceType.XHR,
233+
) as interception:
234+
await tab.get(sample_file("profile.html"))
235+
body, _ = await interception.response_body
236+
await interception.continue_request()
237+
238+
assert body is not None
239+
# original_response = loads(body)
240+
# assert original_response["name"] == "Zendriver"

tests/sample_data/profile.html

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
<!DOCTYPE html>
2+
<html lang="en">
3+
<head>
4+
<meta charset="UTF-8">
5+
<title>Fetch Example</title>
6+
<script>
7+
window.onload = function () {
8+
fetch('https://cdpdriver.github.io/examples/user-data.json')
9+
.then(response => response.json())
10+
.then(data => {
11+
document.getElementById('result').textContent = JSON.stringify(data, null, 2);
12+
})
13+
.catch(error => {
14+
document.getElementById('result').textContent = 'Error: ' + error;
15+
});
16+
};
17+
</script>
18+
</head>
19+
<body>
20+
<h1>Fetch Result</h1>
21+
<pre id="result">Loading...</pre>
22+
</body>
23+
</html>

zendriver/core/connection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ class Connection(metaclass=CantTouchThis):
188188
attached: bool
189189
websocket: websockets.asyncio.client.ClientConnection | None = None
190190
_target: cdp.target.TargetInfo | None
191+
_current_id_mutex: asyncio.Lock = asyncio.Lock()
191192

192193
def __init__(
193194
self,
@@ -468,7 +469,8 @@ async def send(
468469
tx.connection = self
469470
if not self.mapper:
470471
self.__count__ = itertools.count(0)
471-
tx.id = next(self.__count__)
472+
async with self._current_id_mutex:
473+
tx.id = next(self.__count__)
472474
self.mapper.update({tx.id: tx})
473475
if not _is_update:
474476
await self._register_handlers()

zendriver/core/intercept.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import asyncio
2+
import typing
3+
4+
from zendriver import cdp
5+
from zendriver.cdp.fetch import HeaderEntry, RequestStage, RequestPattern
6+
from zendriver.cdp.network import ResourceType
7+
from zendriver.core.connection import Connection
8+
9+
10+
class BaseFetchInterception:
11+
"""
12+
Base class to wait for a Fetch response matching a URL pattern.
13+
Use this to collect and decode a paused fetch response, while keeping
14+
the use block clean and returning its own result.
15+
16+
:param tab: The Tab instance to monitor.
17+
:param url_pattern: The URL pattern to match requests and responses.
18+
:param request_stage: The stage of the fetch request to intercept (e.g., request or response).
19+
:param resource_type: The type of resource to intercept (e.g., document, script, etc.).
20+
"""
21+
22+
def __init__(
23+
self,
24+
tab: Connection,
25+
url_pattern: str,
26+
request_stage: RequestStage,
27+
resource_type: ResourceType,
28+
):
29+
self.tab = tab
30+
self.url_pattern = url_pattern
31+
self.request_stage = request_stage
32+
self.resource_type = resource_type
33+
self.response_future: asyncio.Future[cdp.fetch.RequestPaused] = asyncio.Future()
34+
35+
async def _response_handler(self, event: cdp.fetch.RequestPaused):
36+
"""
37+
Internal handler for response events.
38+
:param event: The response event.
39+
:type event: cdp.fetch.RequestPaused
40+
"""
41+
self._remove_response_handler()
42+
self.response_future.set_result(event)
43+
44+
def _remove_response_handler(self):
45+
"""
46+
Remove the response event handler.
47+
"""
48+
self.tab.remove_handlers(cdp.fetch.RequestPaused, self._response_handler)
49+
50+
async def __aenter__(self):
51+
"""
52+
Enter the context manager, adding request and response handlers.
53+
"""
54+
await self.tab.send(
55+
cdp.fetch.enable(
56+
[
57+
RequestPattern(
58+
url_pattern=self.url_pattern,
59+
request_stage=self.request_stage,
60+
resource_type=self.resource_type,
61+
)
62+
]
63+
)
64+
)
65+
self.tab.enabled_domains.append(
66+
cdp.fetch
67+
) # trick to avoid another `fetch.enable` call by _register_handlers
68+
self.tab.add_handler(cdp.fetch.RequestPaused, self._response_handler)
69+
return self
70+
71+
async def __aexit__(self, *args):
72+
"""
73+
Exit the context manager, removing request and response handlers.
74+
"""
75+
self._remove_response_handler()
76+
await self.tab.send(cdp.fetch.disable())
77+
78+
@property
79+
async def request(self):
80+
"""
81+
Get the matched request.
82+
:return: The matched request.
83+
:rtype: cdp.network.request
84+
"""
85+
return (await self.response_future).request
86+
87+
@property
88+
async def response_body(self) -> tuple[str, bool]:
89+
"""
90+
Get the body of the matched response.
91+
:return: The response body.
92+
:rtype: str
93+
"""
94+
request_id = (await self.response_future).request_id
95+
body = await self.tab.send(cdp.fetch.get_response_body(request_id=request_id))
96+
return body
97+
98+
async def fail_request(self, error_reason: cdp.network.ErrorReason) -> None:
99+
request_id = (await self.response_future).request_id
100+
await self.tab.send(
101+
cdp.fetch.fail_request(request_id=request_id, error_reason=error_reason)
102+
)
103+
104+
async def continue_request(
105+
self,
106+
url: typing.Optional[str] = None,
107+
method: typing.Optional[str] = None,
108+
post_data: typing.Optional[str] = None,
109+
headers: typing.Optional[typing.List[HeaderEntry]] = None,
110+
intercept_response: typing.Optional[bool] = None,
111+
) -> None:
112+
request_id = (await self.response_future).request_id
113+
await self.tab.send(
114+
cdp.fetch.continue_request(
115+
request_id=request_id,
116+
url=url,
117+
method=method,
118+
post_data=post_data,
119+
headers=headers,
120+
intercept_response=intercept_response,
121+
)
122+
)
123+
124+
async def fulfill_request(
125+
self,
126+
response_code: int,
127+
response_headers: typing.Optional[typing.List[HeaderEntry]] = None,
128+
binary_response_headers: typing.Optional[str] = None,
129+
body: typing.Optional[str] = None,
130+
response_phrase: typing.Optional[str] = None,
131+
) -> None:
132+
request_id = (await self.response_future).request_id
133+
await self.tab.send(
134+
cdp.fetch.fulfill_request(
135+
request_id=request_id,
136+
response_code=response_code,
137+
response_headers=response_headers,
138+
binary_response_headers=binary_response_headers,
139+
body=body,
140+
response_phrase=response_phrase,
141+
)
142+
)
143+
144+
async def continue_response(
145+
self,
146+
response_code: typing.Optional[int] = None,
147+
response_phrase: typing.Optional[str] = None,
148+
response_headers: typing.Optional[typing.List[HeaderEntry]] = None,
149+
binary_response_headers: typing.Optional[str] = None,
150+
) -> None:
151+
request_id = (await self.response_future).request_id
152+
await self.tab.send(
153+
cdp.fetch.continue_response(
154+
request_id=request_id,
155+
response_code=response_code,
156+
response_phrase=response_phrase,
157+
response_headers=response_headers,
158+
binary_response_headers=binary_response_headers,
159+
)
160+
)

zendriver/core/tab.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
import webbrowser
1414
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
1515

16+
from .intercept import BaseFetchInterception
1617
from .. import cdp
1718
from . import element, util
1819
from .config import PathLike
1920
from .connection import Connection, ProtocolException
2021
from .expect import DownloadExpectation, RequestExpectation, ResponseExpectation
22+
from ..cdp.fetch import RequestStage
23+
from ..cdp.network import ResourceType
2124

2225
if TYPE_CHECKING:
2326
from .browser import Browser
@@ -1233,6 +1236,28 @@ def expect_download(self) -> DownloadExpectation:
12331236
"""
12341237
return DownloadExpectation(self)
12351238

1239+
def intercept(
1240+
self,
1241+
url_pattern: str,
1242+
request_stage: RequestStage,
1243+
resource_type: ResourceType,
1244+
) -> BaseFetchInterception:
1245+
"""
1246+
Sets up interception for network requests matching a URL pattern, request stage, and resource type.
1247+
1248+
:param url_pattern: URL string or regex pattern to match requests.
1249+
:type url_pattern: Union[str, re.Pattern[str]]
1250+
:param request_stage: Stage of the request to intercept (e.g., request, response).
1251+
:type request_stage: RequestStage
1252+
:param resource_type: Type of resource (e.g., Document, Script, Image).
1253+
:type resource_type: ResourceType
1254+
:return: A BaseFetchInterception instance for further configuration or awaiting intercepted requests.
1255+
:rtype: BaseFetchInterception
1256+
1257+
Use this to block, modify, or inspect network traffic for specific resources during browser automation.
1258+
"""
1259+
return BaseFetchInterception(self, url_pattern, request_stage, resource_type)
1260+
12361261
async def download_file(self, url: str, filename: Optional[PathLike] = None):
12371262
"""
12381263
downloads file by given url.

0 commit comments

Comments
 (0)