Tags   Server TCP async

Back

网络通信技术实践: (6)短连接聊天程序范例

1 项目简介

现在我们来设计一个聊天程序。在这个程序中,有一个服务端和$n$个客户端,客户端分别编号为$1\sim n$。客户端会向服务端发起两种请求:

  1. 发送消息给指定客户端;
  2. 接收本客户端的消息。

可以注意到,服务端不会主动向客户端发送消息。服务端接收消息也需要自身发起请求,从服务器的响应中获取消息。由于服务端是即时响应的,因此连接的过程很短,这称之为短连接

相对应的,存在长连接:在这种情况下,客户端一旦连接了服务端,就会试图一直保持连接。由于连接一直保持,因此服务端一旦接收到消息,就可以立刻从保持的连接发送给指定客户端。

然而,长连接并不稳定。由于网络环境多变,长连接可能断线而不自知。这需要通过定时发送心跳数据包,来检测连接是否仍然保持。一旦断线,则立刻试图重新连接。

我们在此规定两种请求以及对应的响应的格式:

1. 发送消息给指定客户端

请求格式:S\n本机ID\n目标ID\n发送内容,其中\n表示换行符。例如,S\n100\n1\nHello表示100号客户端向1号客户端发送信息Hello。当目标ID为0时,表示向所有客户端发送这条消息。

响应格式:Y\n表示成功,N\ninvalid\n表示失败。

2. 接收本客户端的消息

请求格式:R\n本机ID,其中\n表示换行符。例如,R\n1表示1号客户端接收信息。

响应格式:N\ninvalid\n表示失败。成功时,返回多行内容:

第1行为字符Y,第2行为消息条数$n$;

第$(2i+2)$行为第$i$条消息的发送端ID;

第$(2i+3)$行为第$i$条消息的内容。

根据上述需求,我们可以写出以下代码。

2 服务端代码

为了稳定性和可用性考虑,服务端对于每个客户端的消息存储数量设置了上限。消息也设置了超时时间,以免接受到过于古早的消息。具体代码如下:

import asyncio,time
from collections import deque
from typing import Optional

MAX_CONNECTION = 128    # 最大可用127个ID,1~127
MAX_QUE_LEN = 32        # 每个ID最多保存的消息数量为32
DEF_EXPIRE = 180        # 消息超时时间为180秒
PORT = 8888             # 服务运行端口

class Message: # 一条消息
    time:float
    content:str
    def __init__(self,content:str,mtime:float=0):
        self.time=time.time() if mtime==0 else mtime
        self.content=content

class MessageQue: # 某个客户端的消息队列
    _que:deque[Message]
    _len:int
    _mlen:int
    def __init__(self,max_length:Optional[int]=None):
        self._mlen=max_length
        self._que=deque(maxlen=self._mlen)
        self._len=0
    def push(self,text:str):
        self._que.append(Message(text))
        self._len+=1
    def pop(self,timelim:float=0)->str:
        ret=None
        while ret!=None and self._len>0:
            ret=self._que.popleft()
            self._len-=1
            if ret.time<timelim: ret=None
        return ret.content
    def popall(self,timelim:float=0)->list[str]:
        ret=[]
        while self._len>0:
            itm=self._que.popleft()
            self._len-=1
            if itm.time>=timelim:
                ret.append(itm.content)
        return ret
    def clear(self,timelim:float=time.time()):
        while self._que[0].time<timelim:
            self._que.popleft()
            self._len-=1

class MessagePool: # 所有客户端的消息队列构成消息池
    max_id:int
    messages:list[MessageQue]
    default_expire:float
    def __init__(self,max_id:int,max_que_len:int,default_expire:float):
        self.max_id=max_id
        self.default_expire=default_expire
        self.messages=[MessageQue(max_que_len) for _ in range(max_id)]
    def clear(self,timelim:float=0):
        if timelim==0: timelim=time.time()-self.default_expire
        for m in self.messages:
            m.clear(timelim)
    def pop(self,id:int):
        return self.messages[id].popall(time.time()-self.default_expire)
    def push(self,id:int,content:str):
        self.messages[id].push(content)
    def __getitem__(self,index):
        return self.messages[index]

MPOOL = MessagePool(MAX_CONNECTION,MAX_QUE_LEN,DEF_EXPIRE)

async def handle_echo(reader:asyncio.StreamReader, writer:asyncio.StreamWriter):
    addr = writer.get_extra_info('peername')
    to_print=f"{addr} start"
    print(to_print)

    data = await reader.read(1024)
    message = data.decode().split('\n')
    op = message[0]
    try:
        this_id = int(message[1])
    except:
        this_id = -1
    to_print=f"{addr} op = {op}, this_id = {this_id}"

    to_write=f"N\ninvalid\n"
    if this_id > 0 and this_id < MPOOL.max_id:
        if op == 'S':
            try:
                to_id = int(message[2])
            except:
                to_id = -1
            try:
                text = message[3]
            except:
                text = None
            if to_id >=0 and to_id<MPOOL.max_id and text!=None:
                to_print+=f", to_id = {to_id}"
                if to_id==0:
                    for i in range(1,MPOOL.max_id):
                        if i!=this_id: MPOOL.push(i,f"{this_id}\n{text}")
                else:
                    MPOOL.push(to_id,f"{this_id}\n{text}")
                to_write="Y\n"
        elif op == 'R':
            if this_id>0 or this_id<MAX_CONNECTION:
                messages=MPOOL.pop(this_id)
                cnt=len(messages)
                to_write=f"Y\n{cnt}\n"+'\n'.join(messages)+"\n"
                to_print+=f", count = {cnt}"
    print(to_print)
    writer.write(to_write.encode())
    await writer.drain()
    to_print=f"{addr} end"
    print(to_print)
    writer.close()

loop = asyncio.get_event_loop()
coro = asyncio.start_server(handle_echo, '', PORT, loop=loop)
server = loop.run_until_complete(coro)

print('Serving on {}'.format(server.sockets[0].getsockname()))
try:
    loop.run_forever()
except KeyboardInterrupt:
    pass

server.close()
loop.run_until_complete(server.wait_closed())
loop.close()

3 客户端代码

# 需要将本教程第3节中的tcp.py放在同一文件夹下!
from typing import Optional
import time
from tcp import *

class chatclient:
    def __init__(self,server_addr,local_port,my_id):
        self.server_addr=server_addr
        self.my_id=my_id
        self.local_port=local_port
    
    def send(self,to_id,text)->bool:
        clnt=tcpclient(self.local_port)
        if clnt.connect(self.server_addr):
            time.sleep(0.1)
            if clnt.send(f"S\n{self.my_id}\n{to_id}\n{text}\n".encode()):
                time.sleep(0.1)
                data,_=clnt.recv()
                clnt.close()
                data=data.decode().split('\n')
                return data[0].strip()=='Y'
        return False
    
    def recv(self)->Optional[list[tuple[int,str]]]:
        clnt=tcpclient(self.local_port)
        if clnt.connect(self.server_addr):
            time.sleep(0.1)
            if clnt.send(f"R\n{self.my_id}\n".encode()):
                time.sleep(0.1)
                data,_=clnt.recv()
                clnt.close()
                res=data.decode().split('\n')
                if res[0].strip()=='Y':
                    cnt=int(res[1].strip())
                    ret=[]
                    for i in range(cnt):
                        ret.append((res[i*2+2],res[i*2+3]))
                    return ret
        return None

clnt=chatclient(('127.0.0.1',8888),8080,12)
if clnt.send(12,"Hello!"):
    print("Send OK!")
    time.sleep(0.1)
    messages=clnt.recv()
    if messages!=None:
        print("Receive OK!")
        for (from_id,text) in messages:
            print(f"from = {from_id}, text = {text}")
    else:
        print("Fail to receive!")
else:
    print("Fail to send!")

目录

上一节

下一节