Implement base multiplexer

master
Marko Semet 2020-08-08 21:39:42 +02:00
parent cca9962e67
commit ed1ec0265b
4 changed files with 375 additions and 117 deletions

View File

@ -0,0 +1,298 @@
import abc
import asyncio
import struct
class __InternalProtocol(asyncio.Protocol):
__reader: asyncio.StreamReader
def __init__(self, reader: asyncio.StreamReader):
assert isinstance(reader, asyncio.StreamReader)
self.__reader = reader
async def _drain_helper(self):
pass
class __InternalTransport(asyncio.WriteTransport):
__close: bool
__reader: asyncio.StreamReader
def __init__(self, reader: asyncio.StreamReader):
self.__close = False
assert isinstance(reader, asyncio.StreamReader)
self.__reader = reader
def close(self):
self.__close = True
self.__reader.feed_eof()
def is_closing(self):
return self.__close
def write(self, data: bytes):
self.__reader.feed_data(data)
def write_eof(self):
self.close()
def create_internal_stream(loop=None):
if loop is None:
loop = asyncio.get_event_loop()
reader = asyncio.StreamReader(loop=None)
internal_protocol = __InternalProtocol(reader)
internal_transport = __InternalTransport(reader)
writer = asyncio.StreamWriter(
internal_transport, internal_protocol, reader=reader, loop=loop
)
return reader, writer
class _ConnectionServerProtocol(asyncio.Protocol):
__writer: asyncio.StreamWriter
def __init__(self, writer: asyncio.StreamWriter):
assert isinstance(writer, asyncio.StreamWriter)
self.__writer = writer
def _drain_helper(self):
return self.__writer.drain()
class _ConnectionServerTransport(asyncio.WriteTransport):
_close: bool
__writer: asyncio.StreamWriter
__id: int
__format_int = struct.Struct("!I")
def __init__(self, writer: asyncio.StreamWriter, id: int):
self._close = False
assert isinstance(writer, asyncio.StreamWriter)
self.__writer = writer
self.__id = id
def close(self):
if not self._close:
self._close = True
self.__writer.write(b"\x01")
self.__writer.write(self.__format_int.pack(self.__id))
def is_closing(self):
return self._close
def write(self, data: bytes):
assert len(data) <= 0xFFFFFFFF
self.__writer.write(b"\x03")
self.__writer.write(self.__format_int.pack(self.__id))
self.__writer.write(self.__format_int.pack(len(data)))
self.__writer.write(data)
def write_eof(self):
self.close()
class BaseConnectionServer(abc.ABC):
"""
Protocol-Init:
send 32-Bit big-endian: Size of own configuration
send ...: Configuration data (Future)
receive 32-Bit big-endian: Size of remote configuration
receive ...: Remote configuration data (Future)
send 32-Bit big-endian: Size of own applied configuration
send ...: Applied configuration
receive 32-Bit big-endian: Size of remote applied configuration
receive ...: Remote applied configuration (have to be the same)
Protocol-configuration: (FUTURE)
Protocol-applied-configuration: (FUTURE)
Protocol-Messages:
receive 8-Bit: Action flag
Flags:
- 0: Open substream
receive 32-Bit big-endian: ID of the stream (ID receiver perspective)
- 1: Close substream
receive 32-Bit big-endian: ID of the stream (ID receiver perspective)
- 2: Close applied
receive 32-Bit big-endian: ID of the stream (ID receiver perspective)
- 3: Write substream
receive 32-Bit big-endian: ID of the stream (ID receiver perspective)
receive 32-Bit big-endian: Length of content
receive ...: Content
IDs:
0x000000000x7FFFFFFF: Own opend streams
0x800000000xFFFFFFFF: Remote opend streams
"""
__in: asyncio.StreamReader
__out: asyncio.StreamWriter
__free_ids: set
__next_id = 0
__streams: dict
__size_format = struct.Struct("!I")
__inited = False
def __close_stream(self, id: int):
# Close stream
reader, _, transport = self.__streams[id]
transport._close = True
reader.feed_eof()
# Clean up
del self.__streams[id]
if 0x00000000 <= id <= 0x7FFFFFFF:
self.__free_ids.add(id)
def __init__(self, in_io: asyncio.StreamReader, out_io: asyncio.StreamWriter):
if not isinstance(in_io, asyncio.StreamReader):
raise TypeError("in_io have to be a stream reader.")
if not isinstance(out_io, asyncio.StreamWriter):
raise TypeError("out_io have to be a stream writer.")
self.__in, self.__out = in_io, out_io
self.__free_ids = set()
self.__streams = {}
async def init(self):
# Check inited
if self.__inited:
return
else:
self.__inited = True
# Send own config
self.__out.write(self.__size_format.pack(0))
await self.__out.drain()
# Read other config
size = self.__size_format.unpack(
await self.__in.readexactly(self.__size_format.size)
)[0]
await self.__in.readexactly(size)
# Send applied config
self.__out.write(self.__size_format.pack(0))
await self.__out.drain()
# Read other applied config
size = self.__size_format.unpack(
await self.__in.readexactly(self.__size_format.size)
)[0]
await self.__in.readexactly(size)
async def open(self):
# Get id
if self.__free_ids:
id = self.__free_ids.pop()
else:
id = self.__next_id
self.__next_id += 1
assert id <= 0x7FFFFFFF
remote_id = id | 0x80000000
# Gen internal
reader = asyncio.StreamReader(loop=self.__in._loop)
protocol = _ConnectionServerProtocol(self.__out)
transport = _ConnectionServerTransport(self.__out, remote_id)
writer = asyncio.StreamWriter(
transport, protocol, reader=reader, loop=self.__out._loop
)
self.__streams[id] = (reader, writer, transport)
# Announce new id
self.__out.write(b"\x00")
self.__out.write(self.__size_format.pack(remote_id))
# Return result
return reader, writer
@abc.abstractmethod
async def remote_opend(
self, read: asyncio.StreamReader, write: asyncio.StreamWriter
):
raise NotImplementedError("remote_opend isn't implemented.")
async def run_server(self):
# Init
await self.init()
# Main loop
try:
while True:
# Read flag
flag = await self.__in.read(1)
if not flag:
break
else:
flag = flag[0]
assert 0 <= flag <= 3
# Action
if flag == 0: # Remote open
# Get id
id = self.__size_format.unpack(
await self.__in.readexactly(self.__size_format.size)
)
assert 0x80000000 <= id <= 0xFFFFFFFF
remote_id = id & 0x7FFFFFFF
# Gen reader and writer
reader = asyncio.StreamReader(loop=self.__in._loop)
protocol = _ConnectionServerProtocol(self.__out)
transport = _ConnectionServerTransport(self.__out, remote_id)
writer = asyncio.StreamWriter(
transport, protocol, reader=reader, loop=self.__out._loop
)
self.__streams[id] = (reader, writer, transport)
# Run task
self.__out._loop.create_task(self.remote_opend(reader, writer))
elif flag == 1: # Remote closed stream
# Get id
id = self.__size_format.unpack(
await self.__in.readexactly(self.__size_format.size)
)
assert 0x00000000 <= id <= 0xFFFFFFFF
remote_id = id ^ 0x80000000
# Remove and send applied
self.__close_stream(id)
self.__out.write(b"\x02")
self.__out.write(self.__size_format.pack(remote_id))
elif flag == 2: # Remote applied closed stream
# Get id
id = self.__size_format.unpack(
await self.__in.readexactly(self.__size_format.size)
)
assert 0x00000000 <= id <= 0xFFFFFFFF
# Close stream
self.__close_stream(id)
elif flag == 3: # Receive remote message
# Get id and size
tmp = await self.__in.readexactly(self.__size_format.size * 2)
id = self.__size_format.unpack(tmp[: self.__size_format.size])
assert 0x00000000 <= id <= 0xFFFFFFFF
size = self.__size_format.unpack(tmp[self.__size_format.size :])
# Get data
data = await self.__in.readexactly(size)
self.__streams[id][0].feed_data(data)
else:
raise ValueError("Unknown flag %i" % flag)
finally:
# Close and cleanup
await self.close()
async def close(self):
# Close streams
for i in list(self.__streams.keys()):
self.__close_stream(i)
# Close
self.__out.write_eof()
self.__out.close()
await self.__out.drain()

View File

@ -1,74 +0,0 @@
import abc
import asyncio
import struct
class __InternalProtocol(asyncio.Protocol):
__reader:asyncio.StreamReader
def __init__(self, reader:asyncio.StreamReader):
assert isinstance(reader, asyncio.StreamReader)
self.__reader = reader
async def _drain_helper(self):
pass
class __InternalTransport(asyncio.WriteTransport):
__close: bool
__reader: asyncio.StreamReader
def __init__(self, reader: asyncio.StreamReader):
self.__close = True
assert isinstance(reader, asyncio.StreamReader)
self.__reader = reader
def close(self):
self.__close = False
self.__reader.feed_eof()
def is_closing(self):
return not self.__close
def write(self, data: bytes):
self.__reader.feed_data(data)
def create_internal_stream(loop=None):
if loop is None:
loop = asyncio.get_event_loop()
reader = asyncio.StreamReader(loop=None)
internal_protocol = __InternalProtocol(reader)
internal_transport = __InternalTransport(reader)
writer = asyncio.StreamWriter(internal_transport, internal_protocol, reader=reader, loop=loop)
return reader, writer
class BaseConnectionServer(abc.ABC):
__in:asyncio.StreamReader
__out:asyncio.StreamWriter
__size_format = struct.Struct("!I")
def __init__(self, in_io:asyncio.StreamReader, out_io:asyncio.StreamWriter):
if not isinstance(in_io, asyncio.StreamReader):
raise TypeError("in_io have to be a stream reader.")
if not isinstance(out_io, asyncio.StreamWriter):
raise TypeError("out_io have to be a stream writer.")
self.__in, self.__out = in_io, out_io
async def init(self):
# Send own config
self.__out.write(self.__size_format.pack(0))
await self.__out.drain()
# Read other config
size = self.__size_format.unpack(await self.__in.readexactly(self.__size_format.size))[0]
await self.__in.readexactly(size)
# Send applied config
self.__out.write(self.__size_format.pack(0))
await self.__out.drain()
# Read other applied config
size = self.__size_format.unpack(await self.__in.readexactly(self.__size_format.size))[0]
await self.__in.readexactly(size)

View File

@ -0,0 +1,77 @@
from asynci.core import connector
import asyncio
import unittest
class TestInternal(unittest.IsolatedAsyncioTestCase):
async def test_create(self):
reader, writer = connector.create_internal_stream()
async def test_close_read(self):
reader, writer = connector.create_internal_stream()
writer.close()
await writer.drain()
self.assertEqual(await reader.read(10), b"")
async def test_write(self):
reader, writer = connector.create_internal_stream()
writer.write(b"abc")
writer.close()
await writer.drain()
self.assertEqual(await reader.read(), b"abc")
async def test_writelines(self):
reader, writer = connector.create_internal_stream()
writer.writelines([b"a", b"b"])
writer.close()
await writer.drain()
self.assertEqual(await reader.read(), b"ab")
class _BCS(connector.BaseConnectionServer):
async def remote_opend(
self, read: asyncio.StreamReader, write: asyncio.StreamWriter
):
raise NotImplementedError("remote_opend isn't implemented.")
class TestBaseConnectionServer(unittest.IsolatedAsyncioTestCase):
async def test_create(self):
reader, writer = connector.create_internal_stream()
_BCS(reader, writer)
async def test_init(self):
reader, writer = connector.create_internal_stream()
bcs = _BCS(reader, writer)
await bcs.init()
async def _gen_bcss(self, init=True):
reader2, writer1 = connector.create_internal_stream()
reader1, writer2 = connector.create_internal_stream()
bcs1 = _BCS(reader1, writer1)
bcs2 = _BCS(reader2, writer2)
if init:
t1 = asyncio.create_task(bcs1.init())
t2 = asyncio.create_task(bcs2.init())
await t1
await t2
return bcs1, bcs2
async def test_init_dyn(self):
bcs1, bcs2 = await self._gen_bcss(init=False)
t1 = asyncio.create_task(bcs1.init())
t2 = asyncio.create_task(bcs2.init())
await t1
await t2
async def test_close(self):
# Init
bcs1, bcs2 = await self._gen_bcss()
# Run server
t1 = asyncio.create_task(bcs1.run_server())
t2 = asyncio.create_task(bcs2.run_server())
await bcs1.close()
await t1
await t2

View File

@ -1,43 +0,0 @@
from asynci.core import rpc
import asyncio
import unittest
class TestInternal(unittest.IsolatedAsyncioTestCase):
async def test_create(self):
reader, writer = rpc.create_internal_stream()
async def test_close_read(self):
reader, writer = rpc.create_internal_stream()
writer.close()
await writer.drain()
self.assertEqual(await reader.read(10), b"")
async def test_write(self):
reader, writer = rpc.create_internal_stream()
writer.write(b"abc")
writer.close()
await writer.drain()
self.assertEqual(await reader.read(), b"abc")
async def test_writelines(self):
reader, writer = rpc.create_internal_stream()
writer.writelines([b"a", b"b"])
writer.close()
await writer.drain()
self.assertEqual(await reader.read(), b"ab")
class _BCS(rpc.BaseConnectionServer):
pass
class TestBaseConnectionServer(unittest.IsolatedAsyncioTestCase):
async def test_create(self):
reader, writer = rpc.create_internal_stream()
_BCS(reader, writer)
async def test_init(self):
reader, writer = rpc.create_internal_stream()
bcs = _BCS(reader, writer)
await bcs.init()