Skip to content

Commit f39e8f0

Browse files
committed
Fixes to handle retries for WantWriteError and WantReadError in SSL
As discussed in #245
1 parent 7eced3d commit f39e8f0

File tree

4 files changed

+79
-4
lines changed

4 files changed

+79
-4
lines changed

cheroot/makefile.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
# prefer slower Python-based io module
44
import _pyio as io
55
import socket
6+
import time
7+
8+
from OpenSSL import SSL
69

710

811
# Write only 16K at a time to sockets
@@ -31,7 +34,15 @@ def _flush_unlocked(self):
3134
# so perhaps we should conditionally wrap this for perf?
3235
n = self.raw.write(bytes(self._write_buf))
3336
except io.BlockingIOError as e:
34-
n = e.characters_written
37+
n = e.characters_writteni
38+
except (
39+
SSL.WantReadError,
40+
SSL.WantWriteError,
41+
SSL.WantX509LookupError,
42+
):
43+
# these errors require retries with the same data
44+
# if some data has already been written
45+
n = 0
3546
del self._write_buf[:n]
3647

3748

@@ -45,9 +56,15 @@ def __init__(self, sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE):
4556

4657
def read(self, *args, **kwargs):
4758
"""Capture bytes read."""
48-
val = super().read(*args, **kwargs)
49-
self.bytes_read += len(val)
50-
return val
59+
while True:
60+
try:
61+
val = super().read(*args, **kwargs)
62+
self.bytes_read += len(val)
63+
return val
64+
except SSL.WantReadError:
65+
time.sleep(0.1) # allow some retry delay
66+
except SSL.WantWriteError:
67+
time.sleep(0.1)
5168

5269
def has_data(self):
5370
"""Return true if there is buffered data to read."""

cheroot/server.pyi

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
from typing import Any
22

3+
__all__ = (
4+
'ChunkedRFile',
5+
'DropUnderscoreHeaderReader',
6+
'Gateway',
7+
'HTTPConnection',
8+
'HTTPRequest',
9+
'HTTPServer',
10+
'HeaderReader',
11+
'KnownLengthRFile',
12+
'SizeCheckWrapper',
13+
'get_ssl_adapter_class',
14+
)
15+
316
class HeaderReader:
417
def __call__(self, rfile, hdict: Any | None = ...): ...
518

cheroot/test/test_ssl.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import requests
1818
import trustme
1919

20+
from cheroot.makefile import BufferedWriter
21+
2022
from .._compat import (
2123
IS_ABOVE_OPENSSL10,
2224
IS_ABOVE_OPENSSL31,
@@ -625,6 +627,47 @@ def test_ssl_env( # noqa: C901 # FIXME
625627
)
626628

627629

630+
@pytest.fixture
631+
def ssl_writer(mocker):
632+
"""Return a BufferedWriter instance with a mocked raw socket."""
633+
mock_raw = mocker.Mock()
634+
mock_raw.closed = False
635+
writer = BufferedWriter(mock_raw)
636+
637+
writer.mock_raw = mock_raw
638+
639+
return writer
640+
641+
642+
def test_want_write_error_retry(ssl_writer):
643+
"""Test that WantWriteError causes retry with same data."""
644+
test_data = b'hello world'
645+
646+
# Access the mock object via the attribute we attached in the fixture
647+
ssl_writer.mock_raw.write.side_effect = [
648+
OpenSSL.SSL.WantWriteError(),
649+
len(test_data),
650+
]
651+
652+
bytes_written = ssl_writer.write(test_data)
653+
assert bytes_written == len(test_data)
654+
assert ssl_writer.mock_raw.write.call_count == 2
655+
656+
657+
def test_want_read_error_retry(ssl_writer):
658+
"""Test that WantReadError causes retry with same data."""
659+
test_data = b'test data'
660+
661+
# Access the mock object via the attribute we attached in the fixture
662+
ssl_writer.mock_raw.write.side_effect = [
663+
OpenSSL.SSL.WantReadError(),
664+
len(test_data),
665+
]
666+
667+
bytes_written = ssl_writer.write(test_data)
668+
assert bytes_written == len(test_data)
669+
670+
628671
@pytest.mark.parametrize(
629672
'ip_addr',
630673
(

cheroot/workers/threadpool.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import threading
22
from typing import Any
33

4+
__all__ = ('ThreadPool', 'WorkerThread')
5+
46
class TrueyZero:
57
def __add__(self, other): ...
68
def __radd__(self, other): ...

0 commit comments

Comments
 (0)