-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.py
175 lines (140 loc) · 5.68 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import asyncio
import struct
import json
import os
import io
import zipfile
import argparse
import logging
# Base directory that client-supplied file names are joined onto.
# Rebound from --storage when the file is run as a script.
STORAGE_PATH = "./"
# Default listen address; rebound from --host / --port when run as a script.
HOST = "127.0.0.1"
PORT = 8888
class StatusCode:
    """Send fixed status codes to the peer through ``callback.send``.

    Each code is one integer packed with ``struct`` format ``'i'`` (native
    byte order and size), so both ends must agree on that layout.
    """

    # Packed once when the class is created; keyed by the sending method's name.
    PROTOCOL_CODES = {
        name: struct.pack('i', number)
        for name, number in (
            ("ok", 200),
            ("file_exists", 409),
            ("not_found", 404),
            ("bad_request", 400),
        )
    }

    def __init__(self, callback):
        """Keep *callback*, an object exposing an awaitable ``send(bytes)``."""
        self.callback = callback

    async def _emit(self, name):
        """Send the packed code registered under *name*."""
        payload = self.PROTOCOL_CODES[name]
        await self.callback.send(payload)

    async def ok(self):
        """Send status 200."""
        await self._emit('ok')

    async def file_exists(self):
        """Send status 409."""
        await self._emit('file_exists')

    async def not_found(self):
        """Send status 404."""
        await self._emit('not_found')

    async def bad_request(self):
        """Send status 400."""
        await self._emit('bad_request')
class FileHandler:
    """Handle file uploads and downloads for one client connection.

    *callback* is the connection object. This class uses its ``send``
    coroutine, its ``reader`` stream and its ``code`` status sender.
    """

    # Upper bound for a single read from the client while receiving a file.
    CHUNK_SIZE = 2 ** 16

    def __init__(self, callback):
        self.callback = callback

    @staticmethod
    def _resolve(name):
        """Map a client-supplied name to a path inside ``STORAGE_PATH``.

        Returns the resolved absolute path, or ``None`` when *name* is not a
        string or resolves (via ``..``, an absolute path or a symlink) to a
        location outside the storage directory.
        """
        if not isinstance(name, str):
            return None
        root = os.path.realpath(STORAGE_PATH)
        try:
            target = os.path.realpath(os.path.join(root, name))
            contained = os.path.commonpath([root, target]) == root
        except ValueError:
            # Embedded NUL byte, or paths on different drives (Windows).
            return None
        return target if contained else None

    async def read(self, filename):
        """Send *filename* to the peer as length, JSON header, then payload.

        A regular file is sent as-is. Anything else is walked with
        ``os.walk`` and sent as an in-memory deflate-compressed zip whose
        name is the directory name plus ``.zip``. The whole payload is built
        in memory before it is sent.
        """
        if os.path.isfile(filename):
            with open(filename, 'rb') as file:
                payload = file.read()
            name = os.path.basename(filename)
        else:
            buffer = io.BytesIO()
            with zipfile.ZipFile(buffer, mode='w', compression=zipfile.ZIP_DEFLATED) as archive:
                for root, _, files in os.walk(filename):
                    for entry in files:
                        entry_path = os.path.join(root, entry)
                        archive.write(entry_path, os.path.relpath(entry_path, filename))
            # Read only after the archive is closed, so the zip central
            # directory is part of the payload.
            payload = buffer.getvalue()
            # abspath drops a trailing separator; otherwise basename("dir/")
            # is empty and the archive would be named just ".zip".
            name = os.path.basename(os.path.abspath(filename)) + '.zip'
        header = json.dumps({"filename": name, "filesize": len(payload)}).encode()
        await self.callback.send(struct.pack('i', len(header)))
        await self.callback.send(header)
        await self.callback.send(payload)

    async def save_file(self, protocol):
        """Receive one uploaded file described by *protocol*.

        *protocol* carries ``"filename"`` (relative to the storage directory)
        and ``"filesize"`` (number of payload bytes that follow). Exactly one
        status code is sent before any payload is read:

        * ``bad_request`` - the name escapes the storage directory or the
          size is missing, not an integer, or negative;
        * ``file_exists`` - the target already exists (nothing is read);
        * ``ok`` - the payload is then read from the stream into the file.

        Raises:
            asyncio.IncompleteReadError: the peer closed the stream before
                ``filesize`` bytes arrived. The partial file is removed.
        """
        filename = self._resolve(protocol["filename"])
        if filename is None:
            await self.callback.code.bad_request()
            return
        if os.path.exists(filename):
            await self.callback.code.file_exists()
            return
        # Validate the size before answering "ok", so a bad header cannot
        # leave an empty file behind.
        try:
            remaining = int(protocol["filesize"])
        except (KeyError, TypeError, ValueError):
            remaining = -1
        if remaining < 0:
            await self.callback.code.bad_request()
            return

        await self.callback.code.ok()
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        logging.info(f"开始上传文件 {filename}")
        complete = False
        try:
            with open(filename, 'wb') as file:
                while remaining > 0:
                    data = await self.callback.reader.read(min(remaining, self.CHUNK_SIZE))
                    if not data:
                        # At EOF read() returns b'' immediately; without
                        # this check the loop would spin forever and stall
                        # the event loop.
                        raise asyncio.IncompleteReadError(b"", remaining)
                    file.write(data)
                    remaining -= len(data)
            complete = True
        finally:
            if not complete:
                # Remove the truncated file so a retry is not refused with
                # file_exists.
                try:
                    os.remove(filename)
                except OSError:
                    pass

    async def read_file(self, protocol):
        """Answer a download request for ``protocol["filename"]``.

        Sends ``not_found`` when the name escapes the storage directory or
        does not exist; otherwise sends ``ok`` followed by the data produced
        by :meth:`read`.
        """
        filename = self._resolve(protocol["filename"])
        # Paths outside the storage directory are reported as missing, so a
        # client cannot probe the rest of the filesystem.
        if filename is None or not os.path.exists(filename):
            await self.callback.code.not_found()
            return
        await self.callback.code.ok()
        logging.info(f"开始下载文件 {filename}")
        await self.read(filename)
class ScriptListener:
    """Protocol endpoint for a single client connection.

    Incoming frames are a ``struct`` ``'i'`` length (native byte order)
    followed by that many bytes of UTF-8 text. The text is either a command
    word (``"upload"``, ``"download"``, ``"over"``) or a JSON object.
    """

    # Largest frame body accepted from a client. Frames only carry a command
    # word or a small JSON header, so anything bigger is treated as invalid
    # rather than buffered.
    MAX_HEAD_SIZE = 2 ** 20

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        self.reader = reader
        self.writer = writer
        self.file_handler = FileHandler(self)
        self.code = StatusCode(self)

    async def send(self, msg):
        """Write *msg* and wait for the buffer to drain.

        Does nothing when the writer is already closing.
        """
        if self.writer.is_closing():
            return
        self.writer.write(msg)
        await self.writer.drain()

    async def head(self):
        """Read one length-prefixed frame and return its decoded text.

        Raises:
            asyncio.IncompleteReadError: the stream ended before a complete
                frame arrived.
            ValueError: the announced length is negative or larger than
                ``MAX_HEAD_SIZE``, or the body is not valid UTF-8.
        """
        # readexactly() waits for the full count. read(n) may return fewer
        # bytes, which breaks struct.unpack and truncates the header.
        size_field = await self.reader.readexactly(struct.calcsize('i'))
        (size,) = struct.unpack('i', size_field)
        if not 0 <= size <= self.MAX_HEAD_SIZE:
            raise ValueError(f"invalid header length: {size}")
        body = await self.reader.readexactly(size)
        return body.decode()

    async def option(self):
        """Read the frame that names the requested operation."""
        return await self.head()

    async def upload(self):
        """Receive files until the client sends the ``"over"`` frame.

        Every other frame is parsed as JSON and passed to
        ``file_handler.save_file``.
        """
        while True:
            head_struct = await self.head()
            if head_struct == "over":
                break
            protocol = json.loads(head_struct)
            await self.file_handler.save_file(protocol)

    async def download(self):
        """Read one JSON request frame and pass it to ``file_handler.read_file``."""
        protocol = json.loads(await self.head())
        await self.file_handler.read_file(protocol)
# Per-connection callback for the server
async def script_handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Serve one client connection, then close it.

    Reads the operation frame and dispatches to upload or download; any other
    operation is answered with ``not_found``. A malformed request is answered
    with ``bad_request``. The writer is closed on every exit path, including
    when a handler step raises.
    """
    peer = writer.get_extra_info('peername')
    logging.info(f"{peer} 连接成功!")
    handle = ScriptListener(reader, writer)
    try:
        option = await handle.option()
        if option == "upload":
            await handle.code.ok()
            await handle.upload()
        elif option == 'download':
            await handle.code.ok()
            await handle.download()
        else:
            await handle.code.not_found()
    except (EOFError, ConnectionError) as exc:
        # The peer went away mid-request (asyncio.IncompleteReadError is an
        # EOFError); there is nobody left to answer.
        logging.warning("connection with %s ended early: %r", peer, exc)
    except (ValueError, KeyError, struct.error) as exc:
        # Undecodable frame, invalid JSON, or a header missing a required key.
        logging.warning("malformed request from %s: %r", peer, exc)
        try:
            await handle.code.bad_request()
        except ConnectionError:
            pass  # best effort: the peer may already be gone
    finally:
        writer.close()
        try:
            await writer.wait_closed()
        except ConnectionError:
            pass  # the transport was already torn down by the peer
# Main coroutine
async def main():
    """Start the TCP server on ``HOST``:``PORT`` and run ``serve_forever()``.

    Every accepted connection is handled by ``script_handle``.
    """
    server = await asyncio.start_server(script_handle, host=HOST, port=PORT)
    async with server:
        # Log the address of the first listening socket.
        addr = server.sockets[0].getsockname()
        logging.info(f'服务开始运行 {addr} 👌')
        await server.serve_forever()
if __name__ == '__main__':
    # Script entry point: set up logging, let the command line override the
    # module-level defaults, then run the server until Ctrl-C.
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s - %(levelname)s: %(message)s",
        datefmt="%H:%M:%S",
    )

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-S", "--storage",
        dest="storage", type=str, default=STORAGE_PATH, help="存储地址",
    )
    parser.add_argument(
        "-H", "--host",
        dest="host", type=str, default=HOST, help="ip地址",
    )
    parser.add_argument(
        "-P", "--port",
        dest="port", type=int, default=PORT, help="ip端口",
    )
    args = parser.parse_args()

    # Rebind the module globals that main() and FileHandler read.
    PORT, HOST, STORAGE_PATH = args.port, args.host, args.storage

    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logging.warning("服务停止🤚")