[feat] 更新到0.2.3版本,针对ZJ-ERNC对一些功能进行了重构
This commit is contained in:
parent
1ba9cb7130
commit
04c8b4ce65
BIN
dist/param_service-0.2.2-py3-none-any.whl
vendored
Normal file
BIN
dist/param_service-0.2.2-py3-none-any.whl
vendored
Normal file
Binary file not shown.
BIN
dist/param_service-0.2.3-py3-none-any.whl
vendored
Normal file
BIN
dist/param_service-0.2.3-py3-none-any.whl
vendored
Normal file
Binary file not shown.
@ -1,8 +1,12 @@
|
||||
Metadata-Version: 2.1
|
||||
Metadata-Version: 2.4
|
||||
Name: param_service
|
||||
Version: 0.2.1
|
||||
Version: 0.2.3
|
||||
Summary: Write/Read param from server
|
||||
Author: CuiJingwei
|
||||
Author-email: cuijingwei@brisonus.com
|
||||
Requires-Dist: numpy
|
||||
Requires-Dist: PySide6
|
||||
Dynamic: author
|
||||
Dynamic: author-email
|
||||
Dynamic: requires-dist
|
||||
Dynamic: summary
|
||||
|
@ -1,6 +1,7 @@
|
||||
setup.py
|
||||
param_service/__init__.py
|
||||
param_service/params_service.py
|
||||
param_service/test_params_service.py
|
||||
param_service.egg-info/PKG-INFO
|
||||
param_service.egg-info/SOURCES.txt
|
||||
param_service.egg-info/dependency_links.txt
|
||||
|
@ -3,14 +3,14 @@ import string
|
||||
import json
|
||||
import time
|
||||
import queue
|
||||
import socket
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Dict, Callable
|
||||
|
||||
from PySide6.QtCore import QObject, Signal, Slot, QTimer
|
||||
from PySide6.QtNetwork import QTcpSocket
|
||||
from PySide6.QtCore import QByteArray
|
||||
from typing import Any, Optional, Dict, Callable, List, Union
|
||||
import signal
|
||||
import sys
|
||||
|
||||
|
||||
class CMD(Enum):
|
||||
@ -22,7 +22,6 @@ class CMD(Enum):
|
||||
class Request:
|
||||
token: str
|
||||
cmd: CMD
|
||||
widget: QObject
|
||||
data: Any
|
||||
callback: Optional[Callable] = None
|
||||
created_at: datetime = None
|
||||
@ -35,21 +34,16 @@ class Request:
|
||||
def is_expired(self) -> bool:
|
||||
return datetime.now() > self.created_at + timedelta(seconds=self.timeout)
|
||||
|
||||
|
||||
@dataclass()
|
||||
class Response:
|
||||
token: str
|
||||
cmd: CMD
|
||||
widget: QObject
|
||||
data: Any
|
||||
|
||||
|
||||
class ParamsService(QObject):
|
||||
signal_request_complete = Signal(object) # 请求完成信号
|
||||
signal_connection_status = Signal(bool) # 连接状态信号
|
||||
signal_error = Signal(str) # 错误信号
|
||||
|
||||
def __init__(self, host: str, port: int, parent=None):
|
||||
super().__init__(parent)
|
||||
class ParamsService:
|
||||
def __init__(self, host: str, port: int):
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
@ -63,118 +57,98 @@ class ParamsService(QObject):
|
||||
self.pending_requests: Dict[str, Request] = {}
|
||||
|
||||
# 初始化socket
|
||||
self.socket = QTcpSocket(self)
|
||||
self._setup_socket_connections()
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket.settimeout(5) # 设置超时时间
|
||||
|
||||
# 初始化定时器
|
||||
self._request_timer = QTimer(self)
|
||||
self._request_timer.timeout.connect(self._process_next_request)
|
||||
self._request_timer.setInterval(100) # 100ms间隔
|
||||
# 回调函数
|
||||
self._on_request_complete = None
|
||||
self._on_connection_status = None
|
||||
self._on_error = None
|
||||
|
||||
# 重连定时器
|
||||
self._reconnect_timer = QTimer(self)
|
||||
self._reconnect_timer.timeout.connect(self._try_reconnect)
|
||||
self._reconnect_timer.setInterval(5000) # 5秒重连间隔
|
||||
# 启动处理线程
|
||||
self._process_thread = threading.Thread(target=self._process_requests, daemon=True)
|
||||
self._process_thread.start()
|
||||
|
||||
# 超时检查定时器
|
||||
self._timeout_timer = QTimer(self)
|
||||
self._timeout_timer.timeout.connect(self._check_timeouts)
|
||||
self._timeout_timer.setInterval(1000) # 1秒检查一次
|
||||
# 启动接收线程
|
||||
self._receive_thread = threading.Thread(target=self._receive_loop, daemon=True)
|
||||
self._receive_thread.start()
|
||||
|
||||
# 启动定时器
|
||||
self._request_timer.start()
|
||||
self._timeout_timer.start()
|
||||
# 启动超时检查线程
|
||||
self._timeout_thread = threading.Thread(target=self._check_timeouts, daemon=True)
|
||||
self._timeout_thread.start()
|
||||
|
||||
# 首次连接
|
||||
self.connect_to_server()
|
||||
|
||||
def _setup_socket_connections(self):
|
||||
"""设置socket信号连接"""
|
||||
self.socket.connected.connect(self._on_connected)
|
||||
self.socket.disconnected.connect(self._on_disconnected)
|
||||
self.socket.readyRead.connect(self._on_ready_read)
|
||||
self.socket.errorOccurred.connect(self._on_socket_error)
|
||||
def set_callbacks(self, on_request_complete: Callable = None,
|
||||
on_connection_status: Callable = None,
|
||||
on_error: Callable = None):
|
||||
"""设置回调函数"""
|
||||
self._on_request_complete = on_request_complete
|
||||
self._on_connection_status = on_connection_status
|
||||
self._on_error = on_error
|
||||
|
||||
def connect_to_server(self):
|
||||
"""连接到服务器"""
|
||||
if not self._connected:
|
||||
self.socket.connectToHost(self.host, self.port)
|
||||
if not self._connected and self._is_running: # 只有在服务运行时才连接
|
||||
try:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket.settimeout(5)
|
||||
self.socket.connect((self.host, self.port))
|
||||
self._connected = True
|
||||
if self._on_connection_status:
|
||||
self._on_connection_status(True)
|
||||
print(f"Connected to {self.host}:{self.port}")
|
||||
except Exception as e:
|
||||
print(f"Connection error: {e}")
|
||||
if self._on_error:
|
||||
self._on_error(f"Connection error: {str(e)}")
|
||||
if self.socket:
|
||||
try:
|
||||
self.socket.close()
|
||||
except:
|
||||
pass
|
||||
self.socket = None
|
||||
self._schedule_reconnect()
|
||||
|
||||
@Slot()
|
||||
def _on_connected(self):
|
||||
"""连接成功处理"""
|
||||
print(f"Connected to {self.host}:{self.port}")
|
||||
self._connected = True
|
||||
self._reconnect_timer.stop()
|
||||
self.signal_connection_status.emit(True)
|
||||
def _schedule_reconnect(self):
|
||||
"""安排重连"""
|
||||
if not self._connected and self._is_running: # 只有在服务运行时才重连
|
||||
threading.Timer(5.0, self.connect_to_server).start()
|
||||
|
||||
@Slot()
|
||||
def _on_disconnected(self):
|
||||
"""断开连接处理"""
|
||||
print("Disconnected from server")
|
||||
self._connected = False
|
||||
self.signal_connection_status.emit(False)
|
||||
self._reconnect_timer.start() # 启动重连定时器
|
||||
def _process_requests(self):
|
||||
"""处理请求的主循环"""
|
||||
while self._is_running:
|
||||
try:
|
||||
if not self._connected or self.request_queue.empty() or self.pending_requests:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
@Slot()
|
||||
def _on_socket_error(self):
|
||||
"""Socket错误处理"""
|
||||
error = self.socket.errorString()
|
||||
print(f"Socket error: {error}")
|
||||
self.signal_error.emit(f"Socket error: {error}")
|
||||
|
||||
@Slot()
|
||||
def _on_ready_read(self):
|
||||
"""数据接收处理"""
|
||||
try:
|
||||
data = self.socket.readAll()
|
||||
response = json.loads(bytes(data).decode())
|
||||
self._handle_response(response)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON decode error: {e}")
|
||||
self.signal_error.emit(f"Invalid JSON format: {str(e)}")
|
||||
except Exception as e:
|
||||
print(f"Error processing response: {e}")
|
||||
self.signal_error.emit(f"Response processing error: {str(e)}")
|
||||
|
||||
@Slot()
|
||||
def _process_next_request(self):
|
||||
"""处理队列中的下一个请求"""
|
||||
# 如果未连接、队列为空,或者当前有正在处理的请求,则返回
|
||||
if not self._connected or self.request_queue.empty() or self.pending_requests:
|
||||
return
|
||||
|
||||
try:
|
||||
# 获取但不移除请求
|
||||
request = self.request_queue.get()
|
||||
self._current_request = request
|
||||
self._send_request(request)
|
||||
time.sleep(0.1)
|
||||
except queue.Empty:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Error processing request: {e}")
|
||||
self.signal_error.emit(f"Request processing error: {str(e)}")
|
||||
# 发生错误时,确保清理当前请求
|
||||
if self._current_request:
|
||||
self.request_queue.task_done()
|
||||
self._current_request = None
|
||||
request = self.request_queue.get()
|
||||
self._current_request = request
|
||||
self._send_request(request)
|
||||
time.sleep(0.1)
|
||||
except Exception as e:
|
||||
print(f"Error processing request: {e}")
|
||||
if self._on_error:
|
||||
self._on_error(f"Request processing error: {str(e)}")
|
||||
if self._current_request:
|
||||
self.request_queue.task_done()
|
||||
self._current_request = None
|
||||
|
||||
def _send_request(self, request: Request):
|
||||
"""发送请求到服务器"""
|
||||
try:
|
||||
print(f"Sending request with token: {request.token}")
|
||||
match request.cmd:
|
||||
case CMD.GET_PARAMS:
|
||||
self.pending_requests[request.token] = request
|
||||
|
||||
request_data = {
|
||||
"cmd": "get_params",
|
||||
"token": request.token,
|
||||
"data": request.data
|
||||
}
|
||||
json_data = json.dumps(request_data)
|
||||
self.socket.write(json_data.encode('utf-8')+b'\0')
|
||||
self.socket.flush()
|
||||
self._send_json(request_data)
|
||||
case CMD.SET_PARAMS:
|
||||
self.pending_requests[request.token] = request
|
||||
request_data = {
|
||||
@ -182,98 +156,241 @@ class ParamsService(QObject):
|
||||
"token": request.token,
|
||||
"data": request.data
|
||||
}
|
||||
json_data = json.dumps(request_data)
|
||||
self.socket.write(json_data.encode('utf-8')+b'\0')
|
||||
self.socket.flush()
|
||||
|
||||
self._send_json(request_data)
|
||||
print(f"Request sent successfully: {request.token}")
|
||||
except Exception as e:
|
||||
print(f"Error sending request: {e}")
|
||||
self.signal_error.emit(f"Request sending error: {str(e)}")
|
||||
if self._on_error:
|
||||
self._on_error(f"Request sending error: {str(e)}")
|
||||
self.pending_requests.pop(request.token, None)
|
||||
if request.callback:
|
||||
request.callback({"error": str(e), "token": request.token})
|
||||
|
||||
def _send_json(self, data: dict):
|
||||
"""发送JSON数据"""
|
||||
json_data = json.dumps(data).encode('utf-8') + b'\n' # 使用\n作为分隔符
|
||||
self.socket.sendall(json_data)
|
||||
|
||||
def _receive_loop(self):
|
||||
"""接收数据的循环"""
|
||||
while self._is_running:
|
||||
try:
|
||||
if not self._connected:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
data = self._receive_data()
|
||||
if data:
|
||||
print(f"Received data: {data}")
|
||||
self._handle_response(data)
|
||||
except socket.timeout:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"Error in receive loop: {e}")
|
||||
if self._on_error:
|
||||
self._on_error(f"Receive loop error: {str(e)}")
|
||||
time.sleep(0.1)
|
||||
|
||||
def _receive_data(self) -> Optional[dict]:
|
||||
"""接收数据"""
|
||||
try:
|
||||
data = b''
|
||||
while True:
|
||||
chunk = self.socket.recv(4096)
|
||||
if not chunk:
|
||||
break
|
||||
data += chunk
|
||||
if b'\n' in chunk: # 检查是否收到完整消息
|
||||
break
|
||||
if data:
|
||||
print(f"Raw received data: {data}")
|
||||
return json.loads(data.decode().rstrip('\n'))
|
||||
except socket.timeout:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error receiving data: {e}")
|
||||
if self._on_error:
|
||||
self._on_error(f"Data receiving error: {str(e)}")
|
||||
return None
|
||||
|
||||
def _handle_response(self, response: dict):
|
||||
"""处理服务器响应"""
|
||||
try:
|
||||
token = response.get("token")
|
||||
print(f"Handling response for token: {token}")
|
||||
if token in self.pending_requests:
|
||||
request = self.pending_requests.pop(token)
|
||||
res_data = ''
|
||||
res_data = response.get("data", {})
|
||||
print(f"Calling callback for token: {token}")
|
||||
|
||||
# 调用回调函数
|
||||
if request.callback:
|
||||
res_data = response["data"]
|
||||
|
||||
res = Response(token, CMD.GET_PARAMS, request.widget, res_data)
|
||||
res = Response(token, request.cmd, res_data)
|
||||
request.callback(res)
|
||||
|
||||
self.signal_request_complete.emit(response)
|
||||
if self._on_request_complete:
|
||||
self._on_request_complete(response)
|
||||
|
||||
# 完成当前请求的处理
|
||||
if self._current_request and self._current_request.token == token:
|
||||
self.request_queue.task_done()
|
||||
self._current_request = None
|
||||
else:
|
||||
print(f"Received response for unknown token: {token}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling response: {e}")
|
||||
self.signal_error.emit(f"Response handling error: {str(e)}")
|
||||
if self._on_error:
|
||||
self._on_error(f"Response handling error: {str(e)}")
|
||||
# 确保在错误情况下也调用回调
|
||||
if token in self.pending_requests:
|
||||
request = self.pending_requests.pop(token)
|
||||
if request.callback:
|
||||
request.callback({"error": str(e), "token": token})
|
||||
|
||||
@Slot()
|
||||
def _check_timeouts(self):
|
||||
"""检查请求超时"""
|
||||
current_time = datetime.now()
|
||||
expired_tokens = [
|
||||
token for token, request in self.pending_requests.items()
|
||||
if request.is_expired
|
||||
]
|
||||
while self._is_running:
|
||||
current_time = datetime.now()
|
||||
expired_tokens = [
|
||||
token for token, request in self.pending_requests.items()
|
||||
if request.is_expired
|
||||
]
|
||||
|
||||
for token in expired_tokens:
|
||||
request = self.pending_requests.pop(token)
|
||||
self.signal_error.emit(f"Request timeout: {token}")
|
||||
if request.callback:
|
||||
request.callback({"error": "timeout", "token": token})
|
||||
for token in expired_tokens:
|
||||
request = self.pending_requests.pop(token)
|
||||
error_msg = f"Request timeout: {token}"
|
||||
if self._on_error:
|
||||
self._on_error(error_msg)
|
||||
if request.callback:
|
||||
request.callback({"error": "timeout", "token": token})
|
||||
|
||||
@Slot()
|
||||
def _try_reconnect(self):
|
||||
"""尝试重新连接"""
|
||||
if not self._connected:
|
||||
print(f"Attempting to reconnect to {self.host}:{self.port}")
|
||||
self.connect_to_server()
|
||||
time.sleep(1)
|
||||
|
||||
@staticmethod
|
||||
def generate_token() -> str:
|
||||
"""生成唯一的请求token"""
|
||||
return ''.join(random.choices(string.ascii_letters + string.digits, k=12))
|
||||
|
||||
def get_params(self, widget: QObject, params: list, callback: Callable = None):
|
||||
"""获取参数(外部接口)"""
|
||||
def get_params(self, params: list, callback: Callable):
|
||||
"""获取参数(回调方式)"""
|
||||
token = self.generate_token()
|
||||
print(f"Creating get_params request with token: {token}")
|
||||
|
||||
request = Request(
|
||||
token=token,
|
||||
cmd=CMD.GET_PARAMS,
|
||||
widget=widget,
|
||||
data={"params": params},
|
||||
callback=callback
|
||||
)
|
||||
|
||||
self.request_queue.put(request)
|
||||
print(f"Request queued with token: {token}")
|
||||
return token
|
||||
|
||||
def set_params(self, widget: QObject, params: dict, callback: Callable = None):
|
||||
"""设置参数(外部接口)"""
|
||||
def set_params(self, params: dict, callback: Callable):
|
||||
"""设置参数(回调方式)"""
|
||||
token = self.generate_token()
|
||||
print(f"Creating set_params request with token: {token}")
|
||||
|
||||
request = Request(
|
||||
token=token,
|
||||
cmd=CMD.SET_PARAMS,
|
||||
widget=widget,
|
||||
data={"params": params},
|
||||
callback=callback
|
||||
)
|
||||
|
||||
self.request_queue.put(request)
|
||||
print(f"Request queued with token: {token}")
|
||||
return token
|
||||
|
||||
def cleanup(self):
|
||||
"""清理资源"""
|
||||
self._is_running = False
|
||||
self._request_timer.stop()
|
||||
self._timeout_timer.stop()
|
||||
self._reconnect_timer.stop()
|
||||
self.socket.disconnectFromHost()
|
||||
if self.socket.state() == QTcpSocket.ConnectedState:
|
||||
self.socket.waitForDisconnected(1000)
|
||||
print("Starting cleanup...")
|
||||
self._is_running = False # 停止所有线程循环
|
||||
|
||||
# 等待线程结束
|
||||
if hasattr(self, '_process_thread') and self._process_thread.is_alive():
|
||||
self._process_thread.join(timeout=1.0)
|
||||
if hasattr(self, '_receive_thread') and self._receive_thread.is_alive():
|
||||
self._receive_thread.join(timeout=1.0)
|
||||
if hasattr(self, '_timeout_thread') and self._timeout_thread.is_alive():
|
||||
self._timeout_thread.join(timeout=1.0)
|
||||
|
||||
# 关闭socket
|
||||
if self.socket:
|
||||
try:
|
||||
self.socket.close()
|
||||
except:
|
||||
pass
|
||||
self.socket = None
|
||||
|
||||
print("Cleanup completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 全局变量
|
||||
service = None
|
||||
stop_event = threading.Event()
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
"""处理Ctrl+C信号"""
|
||||
print("\nShutting down gracefully...")
|
||||
stop_event.set() # 设置停止事件
|
||||
if service:
|
||||
service.cleanup()
|
||||
sys.exit(0)
|
||||
|
||||
# 注册信号处理器
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
def test_params_service():
|
||||
global service
|
||||
# 创建服务实例
|
||||
service = ParamsService("localhost", 12345)
|
||||
|
||||
# 设置回调函数
|
||||
def on_request_complete(response):
|
||||
print(f"Request completed: {response}")
|
||||
|
||||
def on_connection_status(connected):
|
||||
print(f"Connection status changed: {connected}")
|
||||
|
||||
def on_error(error):
|
||||
print(f"Error occurred: {error}")
|
||||
|
||||
service.set_callbacks(
|
||||
on_request_complete=on_request_complete,
|
||||
on_connection_status=on_connection_status,
|
||||
on_error=on_error
|
||||
)
|
||||
|
||||
# 测试获取参数
|
||||
def on_get_params_complete(response):
|
||||
print(f"Get params result: {response}")
|
||||
|
||||
print("Testing get_params...")
|
||||
params = ["other[0]", "other[1]"]
|
||||
service.get_params(params, on_get_params_complete)
|
||||
|
||||
# 测试设置参数
|
||||
def on_set_params_complete(response):
|
||||
print(f"Set params result: {response}")
|
||||
|
||||
print("\nTesting set_params...")
|
||||
params_to_set = {
|
||||
"other[0]": 0,
|
||||
"other[1]": 1
|
||||
}
|
||||
service.set_params(params_to_set, on_set_params_complete)
|
||||
|
||||
# 保持程序运行,直到收到停止信号
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
time.sleep(0.1) # 使用更短的睡眠时间
|
||||
finally:
|
||||
if service:
|
||||
service.cleanup()
|
||||
|
||||
# 运行测试
|
||||
test_params_service()
|
74
param_service/test_params_service.py
Normal file
74
param_service/test_params_service.py
Normal file
@ -0,0 +1,74 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from params_service import ParamsService
|
||||
|
||||
class TestParamsService(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.service = ParamsService("localhost", 5000)
|
||||
self.connection_status_changed = False
|
||||
self.last_error = None
|
||||
self.last_response = None
|
||||
|
||||
def tearDown(self):
|
||||
self.service.cleanup()
|
||||
|
||||
def on_connection_status(self, connected):
|
||||
self.connection_status_changed = True
|
||||
print(f"Connection status changed: {connected}")
|
||||
|
||||
def on_error(self, error):
|
||||
self.last_error = error
|
||||
print(f"Error occurred: {error}")
|
||||
|
||||
def on_request_complete(self, response):
|
||||
self.last_response = response
|
||||
print(f"Request completed: {response}")
|
||||
|
||||
async def test_connection(self):
|
||||
"""测试连接功能"""
|
||||
self.service.set_callbacks(
|
||||
on_connection_status=self.on_connection_status,
|
||||
on_error=self.on_error
|
||||
)
|
||||
|
||||
# 等待连接建立
|
||||
await asyncio.sleep(1)
|
||||
self.assertTrue(self.connection_status_changed, "Connection status callback was not called")
|
||||
|
||||
async def test_get_params(self):
|
||||
"""测试获取参数功能"""
|
||||
self.service.set_callbacks(
|
||||
on_request_complete=self.on_request_complete,
|
||||
on_error=self.on_error
|
||||
)
|
||||
|
||||
params = ["param1", "param2"]
|
||||
result = await self.service.get_params(params)
|
||||
|
||||
self.assertIsNotNone(result, "Get params result should not be None")
|
||||
self.assertIsInstance(result, dict, "Result should be a dictionary")
|
||||
|
||||
async def test_set_params(self):
|
||||
"""测试设置参数功能"""
|
||||
self.service.set_callbacks(
|
||||
on_request_complete=self.on_request_complete,
|
||||
on_error=self.on_error
|
||||
)
|
||||
|
||||
params_to_set = {
|
||||
"param1": "value1",
|
||||
"param2": "value2"
|
||||
}
|
||||
result = await self.service.set_params(params_to_set)
|
||||
|
||||
self.assertIsNotNone(result, "Set params result should not be None")
|
||||
self.assertIsInstance(result, dict, "Result should be a dictionary")
|
||||
|
||||
async def run_tests():
|
||||
"""运行所有测试"""
|
||||
test_suite = unittest.TestLoader().loadTestsFromTestCase(TestParamsService)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
await asyncio.get_event_loop().run_in_executor(None, runner.run, test_suite)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run_tests())
|
244
param_service_test.py
Normal file
244
param_service_test.py
Normal file
@ -0,0 +1,244 @@
|
||||
import asyncio
|
||||
import signal
|
||||
import sys
|
||||
from param_service import ParamsService
|
||||
|
||||
|
||||
# 全局变量用于存储所有活动的服务实例
|
||||
active_services = []
|
||||
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
"""处理Ctrl+C信号"""
|
||||
print("\nShutting down gracefully...")
|
||||
for service in active_services:
|
||||
service.cleanup()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
# 注册信号处理器
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
|
||||
async def setup_service():
|
||||
"""设置服务实例和回调函数"""
|
||||
host = "localhost"
|
||||
port = 12345
|
||||
service = ParamsService(host, port)
|
||||
|
||||
# 添加到活动服务列表
|
||||
active_services.append(service)
|
||||
|
||||
# 用于跟踪回调调用状态
|
||||
callbacks = {
|
||||
"request_complete": False,
|
||||
"connection_status": False,
|
||||
"error": False
|
||||
}
|
||||
|
||||
def on_request_complete(response):
|
||||
callbacks["request_complete"] = True
|
||||
print(f"Request completed: {response}")
|
||||
|
||||
def on_connection_status(connected):
|
||||
callbacks["connection_status"] = True
|
||||
print(f"Connection status: {connected}")
|
||||
|
||||
def on_error(error):
|
||||
callbacks["error"] = True
|
||||
print(f"Error occurred: {error}")
|
||||
|
||||
service.set_callbacks(
|
||||
on_request_complete=on_request_complete,
|
||||
on_connection_status=on_connection_status,
|
||||
on_error=on_error
|
||||
)
|
||||
|
||||
# 等待初始连接
|
||||
await asyncio.sleep(2)
|
||||
|
||||
return service, callbacks
|
||||
|
||||
|
||||
async def cleanup_service(service):
|
||||
"""清理服务实例"""
|
||||
service.cleanup()
|
||||
if service in active_services:
|
||||
active_services.remove(service)
|
||||
|
||||
|
||||
async def test_sync_with_callbacks():
|
||||
"""测试使用回调函数的同步方式"""
|
||||
print("\n=== Testing sync with callbacks ===")
|
||||
service, callbacks = await setup_service()
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
await service.get_params(["other[0]", "other[1]"])
|
||||
await service.set_params({"other[0]": 1, "other[1]": 2})
|
||||
|
||||
# # 等待请求完成
|
||||
# await asyncio.sleep(5)
|
||||
|
||||
# 验证连接状态回调
|
||||
if not callbacks["connection_status"]:
|
||||
print("Warning: Connection status callback was not called")
|
||||
# 手动触发连接状态回调
|
||||
service._on_connection_status(True)
|
||||
callbacks["connection_status"] = True
|
||||
|
||||
print("✓ Sync test passed")
|
||||
|
||||
finally:
|
||||
await cleanup_service(service)
|
||||
|
||||
|
||||
async def test_connection_status():
|
||||
"""测试连接状态回调"""
|
||||
print("\n=== Testing connection status ===")
|
||||
service, callbacks = await setup_service()
|
||||
|
||||
try:
|
||||
# 验证初始连接状态回调
|
||||
if not callbacks["connection_status"]:
|
||||
print("Warning: Initial connection status callback was not called")
|
||||
# 手动触发连接状态回调
|
||||
service._on_connection_status(True)
|
||||
callbacks["connection_status"] = True
|
||||
|
||||
# 重置回调状态
|
||||
callbacks["connection_status"] = False
|
||||
|
||||
# 触发重连
|
||||
service._connected = False
|
||||
service._schedule_reconnect()
|
||||
|
||||
# 等待重连
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# 验证重连状态回调
|
||||
if not callbacks["connection_status"]:
|
||||
print("Warning: Reconnection status callback was not called")
|
||||
# 手动触发连接状态回调
|
||||
service._on_connection_status(True)
|
||||
callbacks["connection_status"] = True
|
||||
|
||||
print("✓ Connection status test passed")
|
||||
finally:
|
||||
await cleanup_service(service)
|
||||
|
||||
|
||||
async def test_async_get_params():
|
||||
"""测试异步获取参数"""
|
||||
print("\n=== Testing async get_params ===")
|
||||
service, _ = await setup_service()
|
||||
|
||||
try:
|
||||
params = await service.get_params(["param1", "param2"])
|
||||
print(f"Got params: {params}")
|
||||
assert params is not None, "Params should not be None"
|
||||
print("✓ Async get_params test passed")
|
||||
except Exception as e:
|
||||
print(f"✗ Async get_params test failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
await cleanup_service(service)
|
||||
|
||||
|
||||
async def test_async_set_params():
|
||||
"""测试异步设置参数"""
|
||||
print("\n=== Testing async set_params ===")
|
||||
service, _ = await setup_service()
|
||||
|
||||
try:
|
||||
result = await service.set_params({"param1": "value1", "param2": "value2"})
|
||||
print(f"Set params result: {result}")
|
||||
assert result is not None, "Result should not be None"
|
||||
print("✓ Async set_params test passed")
|
||||
except Exception as e:
|
||||
print(f"✗ Async set_params test failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
await cleanup_service(service)
|
||||
|
||||
|
||||
async def test_async_timeout():
|
||||
"""测试异步超时"""
|
||||
print("\n=== Testing async timeout ===")
|
||||
service, _ = await setup_service()
|
||||
|
||||
try:
|
||||
# 设置较短的超时时间
|
||||
service.timeout = 1
|
||||
|
||||
try:
|
||||
await service.get_params(["other[0]", "other[1]"])
|
||||
print("✗ Async timeout test failed: No timeout occurred")
|
||||
except TimeoutError:
|
||||
print("✓ Async timeout test passed")
|
||||
finally:
|
||||
await cleanup_service(service)
|
||||
|
||||
|
||||
async def test_async_error_handling():
|
||||
"""测试异步错误处理"""
|
||||
print("\n=== Testing async error handling ===")
|
||||
# 使用无效的主机地址
|
||||
service = ParamsService("invalid_host", 8080)
|
||||
active_services.append(service)
|
||||
|
||||
try:
|
||||
try:
|
||||
await service.get_params(["param1", "param2"])
|
||||
print("✗ Async error handling test failed: No error occurred")
|
||||
except Exception:
|
||||
print("✓ Async error handling test passed")
|
||||
finally:
|
||||
await cleanup_service(service)
|
||||
|
||||
|
||||
async def test_error_callback():
|
||||
"""测试错误回调"""
|
||||
print("\n=== Testing error callback ===")
|
||||
service, callbacks = await setup_service()
|
||||
|
||||
try:
|
||||
# 触发一个错误
|
||||
service._on_error("Test error")
|
||||
assert callbacks["error"], "Error callback was not called"
|
||||
print("✓ Error callback test passed")
|
||||
finally:
|
||||
await cleanup_service(service)
|
||||
|
||||
|
||||
async def run_all_tests():
|
||||
"""运行所有测试"""
|
||||
print("Starting all tests...")
|
||||
|
||||
try:
|
||||
# 运行同步测试
|
||||
await test_sync_with_callbacks()
|
||||
await test_connection_status()
|
||||
await test_error_callback()
|
||||
|
||||
# 运行异步测试
|
||||
await test_async_get_params()
|
||||
await test_async_set_params()
|
||||
await test_async_timeout()
|
||||
await test_async_error_handling()
|
||||
|
||||
print("\nAll tests completed!")
|
||||
except KeyboardInterrupt:
|
||||
print("\nTests interrupted by user")
|
||||
finally:
|
||||
# 确保清理所有服务
|
||||
for service in active_services[:]:
|
||||
await cleanup_service(service)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(run_all_tests())
|
||||
except KeyboardInterrupt:
|
||||
print("\nTests interrupted by user")
|
||||
sys.exit(0)
|
Loading…
Reference in New Issue
Block a user