asynci/asynci/core/connector.py

299 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import abc
import asyncio
import struct
class __InternalProtocol(asyncio.Protocol):
    """Protocol half of the in-memory stream pair built by create_internal_stream.

    It only remembers the reader it belongs to and supplies the
    ``_drain_helper`` coroutine that ``asyncio.StreamWriter.drain()`` awaits.
    The matching transport hands written bytes to the reader synchronously,
    so draining never has anything to wait for.
    """

    __reader: asyncio.StreamReader

    def __init__(self, reader: asyncio.StreamReader):
        assert isinstance(reader, asyncio.StreamReader)
        self.__reader = reader

    async def _drain_helper(self):
        # Nothing is ever buffered, so the drain completes immediately.
        return None
class __InternalTransport(asyncio.WriteTransport):
    """In-memory write transport that delivers written bytes to a StreamReader.

    ``write`` feeds the data straight into ``reader``; ``close`` and
    ``write_eof`` feed EOF to it. Used by ``create_internal_stream`` to
    build a loopback (reader, writer) pair.
    """

    __close: bool  # True once close()/write_eof() has fed EOF to the reader
    __reader: asyncio.StreamReader

    def __init__(self, reader: asyncio.StreamReader):
        # BaseTransport.__init__ creates the ``extra`` mapping; without it
        # get_extra_info() (and StreamWriter.get_extra_info) raises
        # AttributeError.
        super().__init__()
        self.__close = False
        assert isinstance(reader, asyncio.StreamReader)
        self.__reader = reader

    def close(self):
        """Feed EOF to the reader; calling it again is a no-op."""
        if self.__close:
            return
        self.__close = True
        self.__reader.feed_eof()

    def is_closing(self):
        return self.__close

    def can_write_eof(self):
        # write_eof() is supported (it closes). The WriteTransport default
        # raises NotImplementedError, which breaks StreamWriter.can_write_eof().
        return True

    def write(self, data: bytes):
        """Feed ``data`` to the reader; writes after close are dropped.

        StreamReader.feed_data() must not be called after feed_eof(), so a
        late write would otherwise fail inside the reader.
        """
        if self.__close:
            return
        self.__reader.feed_data(data)

    def write_eof(self):
        self.close()
def create_internal_stream(loop=None):
    """Create a connected in-memory (reader, writer) pair.

    Bytes written to ``writer`` are fed directly into ``reader``; closing
    the writer (or calling ``write_eof``) feeds EOF to ``reader``.

    Args:
        loop: Event loop both streams are bound to. Defaults to
            ``asyncio.get_event_loop()``. NOTE: that call is deprecated
            outside a running loop in recent Python versions, so prefer
            passing ``loop`` or calling this from a coroutine.

    Returns:
        tuple[asyncio.StreamReader, asyncio.StreamWriter]
    """
    if loop is None:
        loop = asyncio.get_event_loop()
    # Bind the reader to the same loop as the writer. Previously this passed
    # loop=None, which silently ignored an explicitly supplied loop.
    reader = asyncio.StreamReader(loop=loop)
    internal_protocol = __InternalProtocol(reader)
    internal_transport = __InternalTransport(reader)
    writer = asyncio.StreamWriter(
        internal_transport, internal_protocol, reader=reader, loop=loop
    )
    return reader, writer
class _ConnectionServerProtocol(asyncio.Protocol):
    """Protocol for one substream of a BaseConnectionServer connection.

    BaseConnectionServer passes its shared output writer here, so flow
    control for a substream is just the flow control of that shared
    writer: draining the substream drains the connection.
    """

    __writer: asyncio.StreamWriter

    def __init__(self, writer: asyncio.StreamWriter):
        assert isinstance(writer, asyncio.StreamWriter)
        self.__writer = writer

    def _drain_helper(self):
        # StreamWriter.drain() awaits whatever is returned here; hand back
        # the shared writer's own drain() awaitable.
        shared_writer = self.__writer
        pending_drain = shared_writer.drain()
        return pending_drain
class _ConnectionServerTransport(asyncio.WriteTransport):
    """Write transport for one substream of a multiplexed connection.

    Frames are written to the shared connection writer:
    a write becomes ``0x03 | id | len(data) | data`` and closing sends
    ``0x01 | id`` (``id`` and lengths are 32-bit big-endian). ``id`` is the
    stream's id from the receiver's perspective.
    """

    # Also set to True directly by BaseConnectionServer when the stream is
    # torn down, which suppresses further close/data frames.
    _close: bool
    __writer: asyncio.StreamWriter
    __id: int
    __format_int = struct.Struct("!I")

    def __init__(self, writer: asyncio.StreamWriter, id: int):
        # BaseTransport.__init__ creates the ``extra`` mapping used by
        # get_extra_info(); skipping it makes that call raise AttributeError.
        super().__init__()
        self._close = False
        assert isinstance(writer, asyncio.StreamWriter)
        self.__writer = writer
        self.__id = id

    def close(self):
        """Send a single "close substream" frame (idempotent)."""
        if self._close:
            return
        self._close = True
        self.__writer.write(b"\x01" + self.__format_int.pack(self.__id))

    def is_closing(self):
        return self._close

    def can_write_eof(self):
        # write_eof() is supported (it closes); the WriteTransport default
        # raises NotImplementedError.
        return True

    def write(self, data: bytes):
        """Send ``data`` as one "write substream" frame.

        Raises:
            ValueError: if ``data`` does not fit the 32-bit length field.
                Checked before anything is written so the shared connection
                never receives a partial frame.

        Writes after the stream is closed are dropped: the peer may already
        have forgotten this id, and an unknown id in a data frame would
        break its side of the connection.
        """
        size = len(data)
        if size > 0xFFFFFFFF:
            raise ValueError("payload of %d bytes exceeds frame limit" % size)
        if self._close:
            return
        pack = self.__format_int.pack
        self.__writer.write(b"\x03" + pack(self.__id) + pack(size))
        self.__writer.write(data)

    def write_eof(self):
        self.close()
class BaseConnectionServer(abc.ABC):
    """
    Multiplex many byte substreams over one StreamReader/StreamWriter pair.

    Subclasses implement :meth:`remote_opend`, which is started as a task
    for every substream the peer opens. :meth:`open` opens a substream from
    this side, :meth:`run_server` performs the handshake and dispatches
    incoming frames until EOF, and :meth:`close` tears everything down.

    Protocol-Init:
        send 32-bit big-endian: size of own configuration
        send ...: configuration data (future)
        receive 32-bit big-endian: size of remote configuration
        receive ...: remote configuration data (future)
        send 32-bit big-endian: size of own applied configuration
        send ...: applied configuration
        receive 32-bit big-endian: size of remote applied configuration
        receive ...: remote applied configuration (must be the same; not
            verified yet - the data is read and discarded)
    Protocol-configuration: (FUTURE)
    Protocol-applied-configuration: (FUTURE)
    Protocol-Messages:
        receive 8-bit: action flag
        Flags:
        - 0: Open substream
            receive 32-bit big-endian: ID of the stream (receiver perspective)
        - 1: Close substream
            receive 32-bit big-endian: ID of the stream (receiver perspective)
        - 2: Close applied (acknowledges a "close substream")
            receive 32-bit big-endian: ID of the stream (receiver perspective)
        - 3: Write substream
            receive 32-bit big-endian: ID of the stream (receiver perspective)
            receive 32-bit big-endian: length of content
            receive ...: content
    IDs:
        0x00000000-0x7FFFFFFF: streams opened by this side
        0x80000000-0xFFFFFFFF: streams opened by the remote side
        The two sides' IDs for the same stream differ only in the top bit.
    """

    __in: asyncio.StreamReader
    __out: asyncio.StreamWriter
    __free_ids: set  # released local ids, handed out again by open()
    __next_id = 0  # smallest local id never handed out yet
    __streams: dict  # id (our perspective) -> (reader, writer, transport)
    __tasks: set  # strong references to running remote_opend tasks
    # Ids closed by both sides at once: the peer's "close applied" is still
    # due, so a local id must not be reused until it arrives.
    __awaiting_applied: set
    __size_format = struct.Struct("!I")
    __inited = False

    def __close_stream(self, stream_id: int, free_id: bool = True) -> bool:
        """Tear down our side of stream ``stream_id``.

        Marks its transport closed (so it sends no more frames), feeds EOF
        to its reader and forgets it. If ``free_id`` is true, a local id is
        returned to the free pool.

        Returns:
            True if the stream existed, False if it was already gone (e.g.
            both sides closed it simultaneously and a duplicate close/applied
            frame arrived).
        """
        entry = self.__streams.pop(stream_id, None)
        if entry is None:
            return False
        reader, _, transport = entry
        transport._close = True
        reader.feed_eof()
        if free_id:
            self.__release_id(stream_id)
        return True

    def __release_id(self, stream_id: int):
        """Make a locally allocated id available to open() again."""
        if 0x00000000 <= stream_id <= 0x7FFFFFFF:
            self.__free_ids.add(stream_id)

    async def __read_u32(self) -> int:
        """Read one 32-bit big-endian unsigned integer from the input."""
        (value,) = self.__size_format.unpack(
            await self.__in.readexactly(self.__size_format.size)
        )
        return value

    def __make_substream(self, stream_id: int, remote_id: int):
        """Create and register the (reader, writer) pair for ``stream_id``.

        ``remote_id`` is the same stream's id from the peer's perspective;
        outgoing frames for this stream carry it.
        """
        reader = asyncio.StreamReader(loop=self.__in._loop)
        protocol = _ConnectionServerProtocol(self.__out)
        transport = _ConnectionServerTransport(self.__out, remote_id)
        writer = asyncio.StreamWriter(
            transport, protocol, reader=reader, loop=self.__out._loop
        )
        self.__streams[stream_id] = (reader, writer, transport)
        return reader, writer

    def __spawn(self, coro):
        """Start ``coro`` as a task and keep it referenced until it finishes.

        The event loop only keeps weak references to tasks, so an
        unreferenced task can be garbage-collected mid-execution.
        """
        task = self.__out._loop.create_task(coro)
        self.__tasks.add(task)
        task.add_done_callback(self.__tasks.discard)
        return task

    def __init__(self, in_io: asyncio.StreamReader, out_io: asyncio.StreamWriter):
        if not isinstance(in_io, asyncio.StreamReader):
            raise TypeError("in_io have to be a stream reader.")
        if not isinstance(out_io, asyncio.StreamWriter):
            raise TypeError("out_io have to be a stream writer.")
        self.__in, self.__out = in_io, out_io
        self.__free_ids = set()
        self.__streams = {}
        self.__tasks = set()
        self.__awaiting_applied = set()

    async def init(self):
        """Run the connection handshake once; later calls return at once.

        Both configuration blocks are currently empty on our side, and the
        peer's blocks are read and discarded.
        """
        if self.__inited:
            return
        self.__inited = True
        # Own configuration (size 0, no data)
        self.__out.write(self.__size_format.pack(0))
        await self.__out.drain()
        # Remote configuration
        await self.__in.readexactly(await self.__read_u32())
        # Own applied configuration (size 0, no data)
        self.__out.write(self.__size_format.pack(0))
        await self.__out.drain()
        # Remote applied configuration
        await self.__in.readexactly(await self.__read_u32())

    async def open(self):
        """Open a new substream towards the peer.

        The "open" frame is only written to the output; it is not drained
        here.

        Returns:
            (reader, writer) for the new substream.

        Raises:
            RuntimeError: if every local stream id is in use.
        """
        if self.__free_ids:
            stream_id = self.__free_ids.pop()
        else:
            if self.__next_id > 0x7FFFFFFF:
                raise RuntimeError("no free stream id available")
            stream_id = self.__next_id
            self.__next_id += 1
        remote_id = stream_id | 0x80000000
        reader, writer = self.__make_substream(stream_id, remote_id)
        # Announce the stream using the peer's id for it
        self.__out.write(b"\x00" + self.__size_format.pack(remote_id))
        return reader, writer

    @abc.abstractmethod
    async def remote_opend(
        self, read: asyncio.StreamReader, write: asyncio.StreamWriter
    ):
        """Handle a substream opened by the peer.

        run_server starts this as a separate task for every "open" frame.
        """
        raise NotImplementedError("remote_opend isn't implemented.")

    async def run_server(self):
        """Run the handshake, then process incoming frames until EOF.

        :meth:`close` is always awaited once the frame loop ends, whether
        the peer sent EOF or an exception escaped.

        Raises:
            ValueError: on an unknown action flag or an invalid open id.
            asyncio.IncompleteReadError: if the input ends inside a frame.
        """
        await self.init()
        pack_id = self.__size_format.pack
        try:
            while True:
                flag = await self.__in.read(1)
                if not flag:
                    break  # EOF from the peer
                flag = flag[0]
                if flag == 0:  # peer opened a stream
                    stream_id = await self.__read_u32()
                    if stream_id < 0x80000000:
                        raise ValueError(
                            "Invalid remote stream id %#x" % stream_id
                        )
                    # Frames we send for it carry the peer's id (top bit cleared)
                    reader, writer = self.__make_substream(
                        stream_id, stream_id & 0x7FFFFFFF
                    )
                    self.__spawn(self.remote_opend(reader, writer))
                elif flag == 1:  # peer closed a stream
                    stream_id = await self.__read_u32()
                    entry = self.__streams.get(stream_id)
                    if entry is None:
                        continue  # already gone; nothing to acknowledge
                    # If we already sent our own "close" for this stream, the
                    # peer will still answer it with "close applied". Keep the
                    # id reserved until then so a reused id is not closed by
                    # that late frame.
                    closed_locally = entry[2]._close
                    self.__close_stream(stream_id, free_id=not closed_locally)
                    if closed_locally:
                        self.__awaiting_applied.add(stream_id)
                    self.__out.write(b"\x02" + pack_id(stream_id ^ 0x80000000))
                elif flag == 2:  # peer acknowledged our close
                    stream_id = await self.__read_u32()
                    if stream_id in self.__awaiting_applied:
                        self.__awaiting_applied.discard(stream_id)
                        self.__release_id(stream_id)
                    else:
                        self.__close_stream(stream_id)
                elif flag == 3:  # data for a stream
                    stream_id = await self.__read_u32()
                    size = await self.__read_u32()
                    data = await self.__in.readexactly(size)
                    entry = self.__streams.get(stream_id)
                    # Data racing with a close for an already-forgotten
                    # stream is dropped rather than killing the connection.
                    if entry is not None:
                        entry[0].feed_data(data)
                else:
                    raise ValueError("Unknown flag %i" % flag)
        finally:
            await self.close()

    async def close(self):
        """Close every substream, then half-close, flush and close the output.

        Failures because the transport cannot half-close or because the
        connection is already gone are ignored; the output is closed in
        every case.
        """
        for stream_id in list(self.__streams):
            self.__close_stream(stream_id)
        self.__awaiting_applied.clear()
        try:
            self.__out.write_eof()
        except (NotImplementedError, OSError):
            pass  # e.g. SSL transports cannot half-close
        try:
            # Flush before closing: draining an already-closed writer raises
            # ConnectionResetError once connection_lost() has run.
            await self.__out.drain()
        except ConnectionError:
            pass
        finally:
            self.__out.close()