◐ Shell
clean mode source ↗

bpo-34975: Add start_tls() method to streams API by icgood · Pull Request #13143 · python/cpython

Expand Up @@ -759,6 +759,72 @@ async def client(path):
self.assertEqual(messages, [])
@unittest.skipIf(ssl is None, 'No ssl module') def test_start_tls(self):
class MyServer:
def __init__(self, loop): self.server = None self.loop = loop
async def handle_client(self, client_reader, client_writer): data1 = await client_reader.readline() client_writer.write(data1) await client_writer.drain() assert client_writer.get_extra_info('sslcontext') is None await client_writer.start_tls( test_utils.simple_server_sslcontext()) assert client_writer.get_extra_info('sslcontext') is not None data2 = await client_reader.readline() client_writer.write(data2) await client_writer.drain() client_writer.close() await client_writer.wait_closed()
def start(self): sock = socket.create_server(('127.0.0.1', 0)) self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client, sock=sock, loop=self.loop)) return sock.getsockname()
def stop(self): if self.server is not None: self.server.close() self.loop.run_until_complete(self.server.wait_closed()) self.server = None
async def client(addr): reader, writer = await asyncio.open_connection( *addr, loop=self.loop) writer.write(b"hello world 1!\n") await writer.drain() msgback1 = await reader.readline() assert writer.get_extra_info('sslcontext') is None await writer.start_tls(test_utils.simple_client_sslcontext()) assert writer.get_extra_info('sslcontext') is not None writer.write(b"hello world 2!\n") await writer.drain() msgback2 = await reader.readline() writer.close() await writer.wait_closed() return msgback1, msgback2
messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
server = MyServer(self.loop) addr = server.start() msg1, msg2 = self.loop.run_until_complete( asyncio.Task(client(addr), loop=self.loop)) server.stop()
self.assertEqual(messages, []) self.assertEqual(msg1, b"hello world 1!\n") self.assertEqual(msg2, b"hello world 2!\n")
@unittest.skipIf(sys.platform == 'win32', "Don't have pipes") def test_read_all_from_pipe_reader(self): # See asyncio issue 168. This test is derived from the example Expand Down