Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,30 @@ async def _sendfile_fallback(self, transp, file, offset, count):
file.seek(offset + total_sent)
await proto.restore()

def _transfer_buffered_data_to_ssl(self, protocol, ssl_protocol):
"""Transfer buffered data from StreamReader to SSL incoming BIO.

When using start_tls() mid-connection (e.g., after reading a
PROXY protocol header), any data already buffered in the
StreamReader would be lost. This transfers that data to the
SSL layer so the handshake can proceed.

Note: This only works with StreamReaderProtocol (used by the
streams API). Custom Protocol implementations that buffer data
must handle this manually before calling start_tls().
"""
if not hasattr(protocol, '_stream_reader'):
return

stream_reader = protocol._stream_reader
if stream_reader is None:
return

buffer = stream_reader._buffer
if buffer:
ssl_protocol._incoming.write(buffer)
buffer.clear()

async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
Expand Down Expand Up @@ -1341,6 +1365,8 @@ async def start_tls(self, transport, protocol, sslcontext, *,
ssl_shutdown_timeout=ssl_shutdown_timeout,
call_connection_made=False)

self._transfer_buffered_data_to_ssl(protocol, ssl_protocol)

# Pause early so that "ssl_protocol.data_received()" doesn't
# have a chance to get called before "ssl_protocol.connection_made()".
transport.pause_reading()
Expand Down
159 changes: 159 additions & 0 deletions Lib/test/test_asyncio/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,165 @@ async def client(addr):
self.assertEqual(msg1, b"hello world 1!\n")
self.assertEqual(msg2, b"hello world 2!\n")

def _run_test_start_tls_behind_proxy(self, send_combined):
"""Test start_tls() when TLS ClientHello arrives with PROXY header.

This simulates HAProxy with send-proxy, where the PROXY protocol
header and TLS handshake data may arrive in the same TCP segment.
Without the fix, buffered TLS data would be lost after start_tls().
"""

def reverse_message(data):
return data.strip()[::-1] + b'\n'

test_message = b"hello world\n"
expected_response = reverse_message(test_message)

class TCPProxyServer:
"""A simple TCP proxy server that adds a PROXY protocol header
before forwarding data to the target server."""

PROXY_LINE = b"PROXY TCP4 127.0.0.1 127.0.0.1 54321 443\r\n"

def __init__(self, loop, target_host, target_port):
self.loop = loop
self.target_host = target_host
self.target_port = target_port
self.server = None

async def _pipe(self, reader, writer):
try:
while True:
data = await reader.read(4096)
if not data:
break
writer.write(data)
await writer.drain()
finally:
writer.close()
await writer.wait_closed()

async def handle_client(self, client_reader, client_writer):
# Connecting to the target server
remote_reader, remote_writer = await asyncio.open_connection(
self.target_host, self.target_port)

# Reading data from the client (TLS ClientHello)
tls_data = await client_reader.read(4096)

if send_combined:
# send everything together: PROXY + TLS data
remote_writer.write(self.PROXY_LINE + tls_data)
await remote_writer.drain()
else:
# send TLS data after the PROXY line
remote_writer.write(self.PROXY_LINE)
await remote_writer.drain()
await asyncio.sleep(0.01)
remote_writer.write(tls_data)
await remote_writer.drain()

await asyncio.gather(
self._pipe(client_reader, remote_writer),
self._pipe(remote_reader, client_writer),
)

def start(self):
sock = socket.create_server(('127.0.0.1', 0))
self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client, sock=sock))
return sock.getsockname()

def stop(self):
if self.server:
self.server.close()
self.loop.run_until_complete(self.server.wait_closed())
self.server = None

class ServerWithSendProxySupport:
"""A server that supports the PROXY protocol and starts TLS
after receiving the PROXY header."""

def __init__(self, test_case, loop):
self.test = test_case
self.server = None
self.loop = loop

async def handle_client(self, client_reader, client_writer):
proxy_line = await client_reader.readline()
self.test.assertEqual(proxy_line, TCPProxyServer.PROXY_LINE)

# Now we can start TLS
self.test.assertIsNone(
client_writer.get_extra_info('sslcontext'))
await client_writer.start_tls(
test_utils.simple_server_sslcontext()
)
self.test.assertIsNotNone(
client_writer.get_extra_info('sslcontext'))

data = await client_reader.readline()
client_writer.write(reverse_message(data))
await client_writer.drain()
client_writer.close()
await client_writer.wait_closed()

def start(self):
sock = socket.create_server(('127.0.0.1', 0))
self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client,
sock=sock))
return sock.getsockname()

def stop(self):
if self.server is not None:
self.server.close()
self.loop.run_until_complete(self.server.wait_closed())
self.server = None

async def client(addr, test_case):
reader, writer = await asyncio.open_connection(*addr)

test_case.assertIsNone(writer.get_extra_info('sslcontext'))
await writer.start_tls(test_utils.simple_client_sslcontext())
test_case.assertIsNotNone(writer.get_extra_info('sslcontext'))

writer.write(test_message)
await writer.drain()
msgback = await reader.readline()
writer.close()
await writer.wait_closed()
return msgback

messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

server = ServerWithSendProxySupport(self, self.loop)
server_addr = server.start()

proxy = TCPProxyServer(self.loop, *server_addr)
proxy_addr = proxy.start()

msg = self.loop.run_until_complete(
asyncio.wait_for(client(proxy_addr, self), timeout=5.0)
)

proxy.stop()
server.stop()

self.assertEqual(messages, [])
self.assertEqual(msg, expected_response)

@unittest.skipIf(ssl is None, 'No ssl module')
def test_start_tls_behind_proxy_send_combined(self):
# Test with sending PROXY header and TLS data in one packet
self._run_test_start_tls_behind_proxy(send_combined=True)

@unittest.skipIf(ssl is None, 'No ssl module')
def test_start_tls_behind_proxy_send_separate(self):
# Test with sending PROXY header and TLS data in separate packets
self._run_test_start_tls_behind_proxy(send_combined=False)

def test_streamreader_constructor_without_loop(self):
with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
asyncio.StreamReader()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fix :meth:`asyncio.StreamWriter.start_tls` to transfer buffered data from
:class:`~asyncio.StreamReader` to the SSL layer, preventing data loss when
upgrading a connection to TLS mid-stream (e.g., when implementing PROXY
protocol support).
Loading