"""Multiplex independent asyncio byte substreams over one stream pair."""

import abc
import asyncio
import struct


class __InternalProtocol(asyncio.Protocol):
    """Protocol half of the in-memory pipe built by create_internal_stream()."""

    __reader: asyncio.StreamReader

    def __init__(self, reader: asyncio.StreamReader):
        assert isinstance(reader, asyncio.StreamReader)
        self.__reader = reader

    async def _drain_helper(self):
        # StreamWriter.drain() awaits this private hook. Writes are pushed
        # straight into the target StreamReader, so there is nothing to wait for.
        pass

    async def _get_close_waiter(self, stream):
        # StreamWriter.wait_closed() awaits this private hook (Python 3.8+).
        # Closing the pipe takes effect immediately, so return at once.
        return None


class __InternalTransport(asyncio.WriteTransport):
    """Write transport that feeds every written chunk into a StreamReader."""

    __close: bool
    __reader: asyncio.StreamReader

    def __init__(self, reader: asyncio.StreamReader):
        # BaseTransport.__init__ sets up the storage used by get_extra_info().
        super().__init__()
        assert isinstance(reader, asyncio.StreamReader)
        self.__close = False
        self.__reader = reader

    def close(self):
        if not self.__close:
            self.__close = True
            self.__reader.feed_eof()

    def is_closing(self):
        return self.__close

    def write(self, data: bytes):
        # StreamReader.feed_data() rejects data after feed_eof(), so anything
        # written after close() is dropped.
        if not self.__close:
            self.__reader.feed_data(data)

    def can_write_eof(self):
        return True

    def write_eof(self):
        self.close()


def create_internal_stream(loop=None):
    """Return a connected in-memory ``(reader, writer)`` pair.

    Bytes written to ``writer`` become readable from ``reader``; closing the
    writer (or calling ``write_eof()``) signals EOF to the reader.

    :param loop: event loop both ends are bound to; defaults to
        ``asyncio.get_event_loop()``.
    """
    if loop is None:
        loop = asyncio.get_event_loop()
    reader = asyncio.StreamReader(loop=loop)
    internal_protocol = __InternalProtocol(reader)
    internal_transport = __InternalTransport(reader)
    writer = asyncio.StreamWriter(
        internal_transport, internal_protocol, reader=reader, loop=loop
    )
    return reader, writer


class _ConnectionServerProtocol(asyncio.Protocol):
    """Protocol for a substream writer; flow control follows the shared writer."""

    __writer: asyncio.StreamWriter

    def __init__(self, writer: asyncio.StreamWriter):
        assert isinstance(writer, asyncio.StreamWriter)
        self.__writer = writer

    def _drain_helper(self):
        # Substream frames are written into the shared outgoing writer, so
        # draining a substream means draining that writer.
        return self.__writer.drain()

    async def _get_close_waiter(self, stream):
        # StreamWriter.wait_closed() awaits this private hook (Python 3.8+).
        # The close frame is queued synchronously by close(), so return at once.
        return None


class _ConnectionServerTransport(asyncio.WriteTransport):
    """Write transport that frames one substream's data onto the shared writer."""

    # Also set to True by BaseConnectionServer when the stream is torn down.
    _close: bool
    __writer: asyncio.StreamWriter
    # Stream id as seen by the remote side; written into every frame.
    __id: int
    __format_int = struct.Struct("!I")
    # Write frame header: flag, stream id, payload length.
    __header = struct.Struct("!BII")

    def __init__(self, writer: asyncio.StreamWriter, id: int):
        # BaseTransport.__init__ sets up the storage used by get_extra_info().
        super().__init__()
        assert isinstance(writer, asyncio.StreamWriter)
        self._close = False
        self.__writer = writer
        self.__id = id

    def close(self):
        """Send a close frame (flag 1) for this substream, at most once."""
        if not self._close:
            self._close = True
            self.__writer.write(b"\x01" + self.__format_int.pack(self.__id))

    def is_closing(self):
        return self._close

    def write(self, data: bytes):
        """Send ``data`` as one write frame (flag 3).

        Writes after the stream is closed are dropped: the remote side may
        already have forgotten this id and could not route the frame.

        :raises ValueError: if ``data`` does not fit a 32-bit length field.
        """
        if self._close:
            return
        if len(data) > 0xFFFFFFFF:
            raise ValueError("data too large for a single frame")
        self.__writer.write(self.__header.pack(3, self.__id, len(data)))
        self.__writer.write(data)

    def can_write_eof(self):
        return True

    def write_eof(self):
        self.close()


class BaseConnectionServer(abc.ABC):
    """
    Multiplex independent byte substreams over one StreamReader/StreamWriter pair.

    Subclasses implement :meth:`remote_opend`, which is started as a task for
    every substream the remote side opens; :meth:`open` opens one locally.
    :meth:`run_server` must be running for incoming frames to be delivered.

    Protocol-Init:
        send    32-bit big-endian: size of own configuration
        send    ...: configuration data (future)
        receive 32-bit big-endian: size of remote configuration
        receive ...: remote configuration data (future)
        send    32-bit big-endian: size of own applied configuration
        send    ...: applied configuration
        receive 32-bit big-endian: size of remote applied configuration
        receive ...: remote applied configuration (must match; not yet checked)

    Protocol-configuration: (future)

    Protocol-applied-configuration: (future)

    Protocol-Messages:
        receive 8-bit: action flag
        Flags:
        - 0: Open substream
            receive 32-bit big-endian: stream ID (receiver's perspective)
        - 1: Close substream
            receive 32-bit big-endian: stream ID (receiver's perspective)
        - 2: Close applied (acknowledges a received close)
            receive 32-bit big-endian: stream ID (receiver's perspective)
        - 3: Write substream
            receive 32-bit big-endian: stream ID (receiver's perspective)
            receive 32-bit big-endian: content length
            receive ...: content

    IDs:
        0x00000000-0x7FFFFFFF: streams opened by this side
        0x80000000-0xFFFFFFFF: streams opened by the remote side
    """

    __in: asyncio.StreamReader
    __out: asyncio.StreamWriter
    # Own ids (< 0x80000000) that may be handed out again by open().
    __free_ids: set
    # Own ids whose stream is gone but whose close-applied ack is still
    # expected; they must not be reused until that ack arrives.
    __pending_release: set
    __next_id: int
    # stream id (local perspective) -> (reader, writer, transport)
    __streams: dict
    # Running remote_opend() tasks; referenced so they are not garbage-collected.
    __tasks: set
    __size_format = struct.Struct("!I")
    # High bit distinguishes remote-opened ids from own ids.
    __REMOTE_BIT = 0x80000000

    def __init__(self, in_io: asyncio.StreamReader, out_io: asyncio.StreamWriter):
        if not isinstance(in_io, asyncio.StreamReader):
            raise TypeError("in_io have to be a stream reader.")
        if not isinstance(out_io, asyncio.StreamWriter):
            raise TypeError("out_io have to be a stream writer.")
        self.__in, self.__out = in_io, out_io
        self.__free_ids = set()
        self.__pending_release = set()
        self.__next_id = 0
        self.__streams = {}
        self.__tasks = set()
        self.__init_task = None

    def __close_stream(self, stream_id: int, *, release: bool = True):
        """Signal EOF on a substream and forget it locally.

        Unknown ids are ignored, which tolerates duplicate close messages
        (for example when both sides close the same stream at once).

        :param release: for an own id, make it reusable immediately; when
            False, park it until the matching close-applied message arrives.
        """
        entry = self.__streams.pop(stream_id, None)
        if entry is None:
            return
        reader, _, transport = entry
        transport._close = True
        reader.feed_eof()
        if stream_id < self.__REMOTE_BIT:
            if release:
                self.__free_ids.add(stream_id)
            else:
                self.__pending_release.add(stream_id)

    async def __read_u32(self) -> int:
        """Read one big-endian unsigned 32-bit integer from the input stream."""
        raw = await self.__in.readexactly(self.__size_format.size)
        return self.__size_format.unpack(raw)[0]

    def __make_stream(self, stream_id: int, remote_id: int):
        """Create and register the reader/writer pair for one substream."""
        loop = asyncio.get_running_loop()
        reader = asyncio.StreamReader(loop=loop)
        protocol = _ConnectionServerProtocol(self.__out)
        transport = _ConnectionServerTransport(self.__out, remote_id)
        writer = asyncio.StreamWriter(transport, protocol, reader=reader, loop=loop)
        self.__streams[stream_id] = (reader, writer, transport)
        return reader, writer

    async def init(self):
        """Run the connection handshake exactly once.

        Concurrent and repeated callers all wait for the same handshake; if it
        failed, every call re-raises its error.
        """
        if self.__init_task is None:
            self.__init_task = asyncio.ensure_future(self.__handshake())
        # shield(): a cancelled caller must not abort the shared handshake.
        await asyncio.shield(self.__init_task)

    async def __handshake(self):
        """Exchange (currently empty) configuration blocks with the remote."""
        # Send own config, then read and discard the remote's.
        self.__out.write(self.__size_format.pack(0))
        await self.__out.drain()
        await self.__in.readexactly(await self.__read_u32())
        # Send own applied config, then read and discard the remote's.
        self.__out.write(self.__size_format.pack(0))
        await self.__out.drain()
        await self.__in.readexactly(await self.__read_u32())

    async def open(self):
        """Open a new substream to the remote side.

        Waits for the handshake first so the open frame never precedes it.

        :returns: ``(reader, writer)`` for the new substream.
        :raises RuntimeError: if all 2**31 own stream ids are in use.
        """
        await self.init()
        if self.__free_ids:
            stream_id = self.__free_ids.pop()
        else:
            if self.__next_id >= self.__REMOTE_BIT:
                raise RuntimeError("no free stream ids")
            stream_id = self.__next_id
            self.__next_id += 1
        remote_id = stream_id | self.__REMOTE_BIT
        reader, writer = self.__make_stream(stream_id, remote_id)
        # Announce the stream using the id as the remote side will see it.
        self.__out.write(b"\x00" + self.__size_format.pack(remote_id))
        return reader, writer

    @abc.abstractmethod
    async def remote_opend(
        self, read: asyncio.StreamReader, write: asyncio.StreamWriter
    ):
        """Handle a substream opened by the remote side (runs as its own task)."""
        raise NotImplementedError("remote_opend isn't implemented.")

    async def run_server(self):
        """Perform the handshake, then dispatch incoming frames until EOF.

        On EOF or any error, all substreams and the outgoing writer are closed
        via :meth:`close`.

        :raises ValueError: on a malformed frame from the remote side.
        """
        await self.init()
        loop = asyncio.get_running_loop()
        try:
            while True:
                flag = await self.__in.read(1)
                if not flag:
                    break  # the underlying input stream reached EOF
                flag = flag[0]

                if flag == 0:  # remote opened a substream
                    stream_id = await self.__read_u32()
                    if stream_id < self.__REMOTE_BIT:
                        raise ValueError("Invalid remote stream id %#x" % stream_id)
                    if stream_id in self.__streams:
                        raise ValueError("Stream id %#x already open" % stream_id)
                    reader, writer = self.__make_stream(
                        stream_id, stream_id ^ self.__REMOTE_BIT
                    )
                    task = loop.create_task(self.remote_opend(reader, writer))
                    self.__tasks.add(task)
                    task.add_done_callback(self.__tasks.discard)

                elif flag == 1:  # remote closed a substream
                    stream_id = await self.__read_u32()
                    entry = self.__streams.get(stream_id)
                    # If we already sent our own close for this stream, the
                    # remote will still acknowledge it; keep the id reserved
                    # until then so that late ack cannot hit a reused id.
                    awaiting_ack = entry is not None and entry[2]._close
                    self.__close_stream(stream_id, release=not awaiting_ack)
                    self.__out.write(
                        b"\x02" + self.__size_format.pack(stream_id ^ self.__REMOTE_BIT)
                    )

                elif flag == 2:  # remote acknowledged our close
                    stream_id = await self.__read_u32()
                    if stream_id in self.__pending_release:
                        self.__pending_release.discard(stream_id)
                        self.__free_ids.add(stream_id)
                    else:
                        self.__close_stream(stream_id)

                elif flag == 3:  # payload for a substream
                    stream_id = await self.__read_u32()
                    size = await self.__read_u32()
                    data = await self.__in.readexactly(size)
                    entry = self.__streams.get(stream_id)
                    # Frames for streams that are already gone are discarded.
                    if entry is not None:
                        entry[0].feed_data(data)

                else:
                    raise ValueError("Unknown flag %i" % flag)
        finally:
            await self.close()

    async def close(self):
        """Close all substreams locally and shut down the outgoing writer.

        Substreams are not announced individually; the remote side observes
        EOF on the underlying stream. Safe to call more than once.
        """
        for stream_id in list(self.__streams):
            self.__close_stream(stream_id)
        self.__pending_release.clear()
        out = self.__out
        if out.is_closing():
            return
        try:
            if out.can_write_eof():
                out.write_eof()
            # Flush buffered frames before the transport is closed.
            await out.drain()
        finally:
            out.close()