现在我们来设计一个聊天程序。在这个程序中,有一个服务端和$n$个客户端,客户端分别编号为$1\sim n$。客户端会向服务端发起两种请求:
可以注意到,服务端不会主动向客户端发送消息。服务端接收消息也需要自身发起请求,从服务器的响应中获取消息。由于服务端是即时响应的,因此连接的过程很短,这称之为短连接。
相对应的,存在长连接:在这种情况下,客户端一旦连接了服务端,就会试图一直保持连接。由于连接一直保持,因此服务端一旦接收到消息,就可以立刻从保持的连接发送给指定客户端。
然而,长连接并不稳定。由于网络环境多变,长连接可能断线而不自知。这需要通过定时发送心跳数据包,来检测连接是否仍然保持。一旦断线,则立刻试图重新连接。
我们在此规定两种请求以及对应的响应的格式:
请求格式:S\n本机ID\n目标ID\n发送内容,其中\n表示换行符。例如,S\n100\n1\nHello表示100号客户端向1号客户端发送信息Hello。当目标ID为0时,表示向所有客户端发送这条消息。
响应格式:Y\n表示成功,N\ninvalid\n表示失败。
请求格式:R\n本机ID,其中\n表示换行符。例如,R\n1表示1号客户端接收信息。
响应格式:N\ninvalid\n表示失败。成功时,返回多行内容:
第1行为字符Y,第2行为消息条数$n$;
第$(2i+2)$行为第$i$条消息的发送端ID;
第$(2i+3)$行为第$i$条消息的内容。
根据上述需求,我们可以写出以下代码。
为了稳定性和可用性考虑,服务端对于每个客户端的消息存储数量设置了上限。消息也设置了超时时间,以免接受到过于古早的消息。具体代码如下:
import asyncio,time
from collections import deque
from typing import Optional
MAX_CONNECTION = 128 # 最大可用127个ID,1~127
MAX_QUE_LEN = 32 # 每个ID最多保存的消息数量为32
DEF_EXPIRE = 180 # 消息超时时间为180秒
PORT = 8888 # 服务运行端口
class Message: # 一条消息
time:float
content:str
def __init__(self,content:str,mtime:float=0):
self.time=time.time() if mtime==0 else mtime
self.content=content
class MessageQue: # 某个客户端的消息队列
_que:deque[Message]
_len:int
_mlen:int
def __init__(self,max_length:Optional[int]=None):
self._mlen=max_length
self._que=deque(maxlen=self._mlen)
self._len=0
def push(self,text:str):
self._que.append(Message(text))
self._len+=1
def pop(self,timelim:float=0)->str:
ret=None
while ret!=None and self._len>0:
ret=self._que.popleft()
self._len-=1
if ret.time<timelim: ret=None
return ret.content
def popall(self,timelim:float=0)->list[str]:
ret=[]
while self._len>0:
itm=self._que.popleft()
self._len-=1
if itm.time>=timelim:
ret.append(itm.content)
return ret
def clear(self,timelim:float=time.time()):
while self._que[0].time<timelim:
self._que.popleft()
self._len-=1
class MessagePool: # 所有客户端的消息队列构成消息池
max_id:int
messages:list[MessageQue]
default_expire:float
def __init__(self,max_id:int,max_que_len:int,default_expire:float):
self.max_id=max_id
self.default_expire=default_expire
self.messages=[MessageQue(max_que_len) for _ in range(max_id)]
def clear(self,timelim:float=0):
if timelim==0: timelim=time.time()-self.default_expire
for m in self.messages:
m.clear(timelim)
def pop(self,id:int):
return self.messages[id].popall(time.time()-self.default_expire)
def push(self,id:int,content:str):
self.messages[id].push(content)
def __getitem__(self,index):
return self.messages[index]
MPOOL = MessagePool(MAX_CONNECTION,MAX_QUE_LEN,DEF_EXPIRE)
async def handle_echo(reader:asyncio.StreamReader, writer:asyncio.StreamWriter):
addr = writer.get_extra_info('peername')
to_print=f"{addr} start"
print(to_print)
data = await reader.read(1024)
message = data.decode().split('\n')
op = message[0]
try:
this_id = int(message[1])
except:
this_id = -1
to_print=f"{addr} op = {op}, this_id = {this_id}"
to_write=f"N\ninvalid\n"
if this_id > 0 and this_id < MPOOL.max_id:
if op == 'S':
try:
to_id = int(message[2])
except:
to_id = -1
try:
text = message[3]
except:
text = None
if to_id >=0 and to_id<MPOOL.max_id and text!=None:
to_print+=f", to_id = {to_id}"
if to_id==0:
for i in range(1,MPOOL.max_id):
if i!=this_id: MPOOL.push(i,f"{this_id}\n{text}")
else:
MPOOL.push(to_id,f"{this_id}\n{text}")
to_write="Y\n"
elif op == 'R':
if this_id>0 or this_id<MAX_CONNECTION:
messages=MPOOL.pop(this_id)
cnt=len(messages)
to_write=f"Y\n{cnt}\n"+'\n'.join(messages)+"\n"
to_print+=f", count = {cnt}"
print(to_print)
writer.write(to_write.encode())
await writer.drain()
to_print=f"{addr} end"
print(to_print)
writer.close()
loop = asyncio.get_event_loop()
coro = asyncio.start_server(handle_echo, '', PORT, loop=loop)
server = loop.run_until_complete(coro)
print('Serving on {}'.format(server.sockets[0].getsockname()))
try:
loop.run_forever()
except KeyboardInterrupt:
pass
server.close()
loop.run_until_complete(server.wait_closed())
loop.close()
# 需要将本教程第3节中的tcp.py放在同一文件夹下!
from typing import Optional
import time
from tcp import *
class chatclient:
def __init__(self,server_addr,local_port,my_id):
self.server_addr=server_addr
self.my_id=my_id
self.local_port=local_port
def send(self,to_id,text)->bool:
clnt=tcpclient(self.local_port)
if clnt.connect(self.server_addr):
time.sleep(0.1)
if clnt.send(f"S\n{self.my_id}\n{to_id}\n{text}\n".encode()):
time.sleep(0.1)
data,_=clnt.recv()
clnt.close()
data=data.decode().split('\n')
return data[0].strip()=='Y'
return False
def recv(self)->Optional[list[tuple[int,str]]]:
clnt=tcpclient(self.local_port)
if clnt.connect(self.server_addr):
time.sleep(0.1)
if clnt.send(f"R\n{self.my_id}\n".encode()):
time.sleep(0.1)
data,_=clnt.recv()
clnt.close()
res=data.decode().split('\n')
if res[0].strip()=='Y':
cnt=int(res[1].strip())
ret=[]
for i in range(cnt):
ret.append((res[i*2+2],res[i*2+3]))
return ret
return None
clnt=chatclient(('127.0.0.1',8888),8080,12)
if clnt.send(12,"Hello!"):
print("Send OK!")
time.sleep(0.1)
messages=clnt.recv()
if messages!=None:
print("Receive OK!")
for (from_id,text) in messages:
print(f"from = {from_id}, text = {text}")
else:
print("Fail to receive!")
else:
print("Fail to send!")