brisonus_app_eq/persistence/data_store.py
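"""Project and parameter persistence for the EQ application.

Note: this module docstring is an added summary of the behaviour below.
DataStore keeps project metadata as JSON files under ``storage_dir`` (default
``data/projects``) and stores each parameter set as a flat
``params/<project>_<param>.csv`` file with ``parameter,value`` rows.
"""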


import csv
import os
import json
from dataclasses import asdict  # used when serialising ProjectData to JSON
from typing import Dict, List, Any, Optional
from datetime import datetime
from persistence.models import *
from component.widget_log.log_handler import logger


class DataStore:
    def __init__(self, storage_dir: str = "data/projects"):
        self.storage_dir = storage_dir
        self.current_project: Optional[str] = None
        self.current_param: Optional[str] = None
        self.current_paramter_name: Optional[str] = None
        self._ensure_storage_dir()

    def _ensure_storage_dir(self):
        """Make sure the storage directory exists."""
        if not os.path.exists(self.storage_dir):
            os.makedirs(self.storage_dir)
        # Make sure the parameter data directory exists as well
        params_dir = os.path.join(self.storage_dir, "params")
        if not os.path.exists(params_dir):
            os.makedirs(params_dir)

    def _get_project_path(self, project_name: str) -> str:
        """Return the path of the project metadata file."""
        return os.path.join(self.storage_dir, f"{project_name}.json")

    def _get_param_path(self, project_name: str, param_name: str) -> str:
        """Return the path of the parameter data file."""
        params_dir = os.path.join(self.storage_dir, "params")
        return os.path.join(params_dir, f"{project_name}_{param_name}.csv")

    def save_project(self, project_name: str, description: str = "") -> bool:
        """Create or update project metadata."""
        try:
            now = datetime.now().isoformat()
            project_data = ProjectData(
                name=project_name,
                created_at=now if not self._project_exists(project_name) else self._get_project_created_time(project_name),
                last_modified=now,
                description=description,
                params={}
            )
            # Persist the project metadata
            self._save_project_metadata(project_name, project_data)
            self.current_project = project_name
            logger.info(f"Project {project_name} saved successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to save project: {e}")
            return False

    def add_param_to_project(self, project_name: str, param_name: str,
                             channel_data: Dict[int, Dict], description: str = "") -> bool:
        """Add a parameter configuration to a project."""
        try:
            # Load the project metadata
            project_data = self.load_project(project_name)
            if not project_data:
                raise ValueError(f"Project {project_name} not found")
            # Create a simplified parameter config (description only)
            param_config = ParamConfig(
                name=param_name,
                created_at=datetime.now().isoformat(),
                description=description,
                channels={}  # channel configuration is no longer stored here
            )
            # Update the project metadata
            project_data.params[param_name] = param_config
            project_data.last_modified = datetime.now().isoformat()
            self._save_project_metadata(project_name, project_data)
            # Save the parameter data to a CSV file
            self._save_param_to_csv(project_name, param_name, channel_data)
            logger.info(f"Parameter {param_name} added to project {project_name} successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to add parameter: {e}")
            return False
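
    # Note on the CSV layout written by _save_param_to_csv below (added summary):
    # the file is a flat two-column table (parameter, value) whose row order
    # follows the dataset.* parameter naming, presumably matching the
    # struct_params.txt structure referenced in _get_param_structure:
    #   dataset.audio_mode, dataset.send_action
    #   mix_parameters[0..5]    -> ch_n, mix_left_data, mix_right_data
    #   eq_parameters[0..119]   -> fc, q, gain, slope, filterType
    #                              (20 filters per channel x 6 channels)
    #   delay_parameters[0..5]  -> ch_n, delay_data
    #   volume_parameters[0..5] -> ch_n, vol_data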
    def _save_param_to_csv(self, project_name: str, param_name: str, channel_data: Dict[int, Dict]):
        """Save the parameter data as CSV, with rows in the expected order."""
        csv_path = self._get_param_path(project_name, param_name)
        # Build the list of parameter rows
        param_data = []
        # Basic parameters
        param_data.append({
            'parameter': 'dataset.audio_mode',
            'value': '0'
        })
        param_data.append({
            'parameter': 'dataset.send_action',
            'value': '0'
        })
        # Mix parameters
        for i in range(6):
            # The channel number is the channel index itself
            ch_value = str(i)
            param_data.append({
                'parameter': f'dataset.tuning_parameters.mix_parameters[{i}].ch_n',
                'value': ch_value
            })
            left_value = str(channel_data.get(i, {}).get('mix_left_data', 0.0))
            param_data.append({
                'parameter': f'dataset.tuning_parameters.mix_parameters[{i}].mix_left_data',
                'value': left_value
            })
            right_value = str(channel_data.get(i, {}).get('mix_right_data', 0.0))
            param_data.append({
                'parameter': f'dataset.tuning_parameters.mix_parameters[{i}].mix_right_data',
                'value': right_value
            })
        # EQ parameters
        for i in range(120):
            # Work out which channel this filter belongs to
            channel_id = i // 20  # assumes at most 20 filters per channel
            filter_idx = i % 20
            filter_data = {}
            if channel_id in channel_data and 'filters' in channel_data[channel_id]:
                filters = channel_data[channel_id]['filters']
                if filter_idx < len(filters):
                    filter_data = filters[filter_idx]
            # Centre frequency
            param_data.append({
                'parameter': f'dataset.tuning_parameters.eq_parameters[{i}].fc',
                'value': str(filter_data.get('fc', 0.0))
            })
            # Q factor
            param_data.append({
                'parameter': f'dataset.tuning_parameters.eq_parameters[{i}].q',
                'value': str(filter_data.get('q', 0.0))
            })
            # Gain
            param_data.append({
                'parameter': f'dataset.tuning_parameters.eq_parameters[{i}].gain',
                'value': str(filter_data.get('gain', 0.0))
            })
            # Slope
            param_data.append({
                'parameter': f'dataset.tuning_parameters.eq_parameters[{i}].slope',
                'value': str(filter_data.get('slope', 0))
            })
            # Filter type
            param_data.append({
                'parameter': f'dataset.tuning_parameters.eq_parameters[{i}].filterType',
                'value': str(filter_data.get('filterType', 0))
            })
        # Delay parameters
        for i in range(6):
            ch_value = str(i)
            param_data.append({
                'parameter': f'dataset.tuning_parameters.delay_parameters[{i}].ch_n',
                'value': ch_value
            })
            delay_value = str(channel_data.get(i, {}).get('delay_data', 0.0))
            param_data.append({
                'parameter': f'dataset.tuning_parameters.delay_parameters[{i}].delay_data',
                'value': delay_value
            })
        # Volume parameters
        for i in range(6):
            ch_value = str(i)
            param_data.append({
                'parameter': f'dataset.tuning_parameters.volume_parameters[{i}].ch_n',
                'value': ch_value
            })
            vol_value = str(channel_data.get(i, {}).get('vol_data', 0.0))
            param_data.append({
                'parameter': f'dataset.tuning_parameters.volume_parameters[{i}].vol_data',
                'value': vol_value
            })
        # Write the CSV file
        with open(csv_path, 'w', newline='') as csvfile:
            fieldnames = ['parameter', 'value']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for param in param_data:
                writer.writerow({
                    'parameter': param['parameter'],
                    'value': param['value']
                })

    def _get_param_structure(self):
        """Parse struct_params.txt to obtain the parameter structure."""
        # Parsing struct_params.txt could be implemented here;
        # for simplicity, the structure is currently hard-coded elsewhere.
        return {}

    def _convert_to_channel_config(self, channel_data: Dict[int, Dict]) -> Dict[int, ChannelConfig]:
        """Convert channel data into ChannelConfig objects."""
        # Since the JSON no longer stores channel configuration, this can stay a no-op.
        return {}

    def load_project(self, project_name: str) -> Optional[ProjectData]:
        """Load project metadata."""
        try:
            file_path = self._get_project_path(project_name)
            if not os.path.exists(file_path):
                return None
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            project_data = ProjectData(**data)
            return project_data
        except Exception as e:
            logger.error(f"Failed to load project: {e}")
            return None

    def load_param_data(self, project_name: str, param_name: str) -> Dict:
        """Load parameter data."""
        try:
            csv_path = self._get_param_path(project_name, param_name)
            if not os.path.exists(csv_path):
                return {}
            param_data = {}
            with open(csv_path, 'r', newline='') as csvfile:
                reader = csv.DictReader(csvfile)
                for row in reader:
                    param_data[row['parameter']] = row['value']
            # Convert to the channel data format
            channel_data = self._convert_csv_to_channel_data(param_data)
            return channel_data
        except Exception as e:
            logger.error(f"Failed to load parameter data: {e}")
            return {}

    def _convert_csv_to_channel_data(self, param_data: Dict) -> Dict[int, Dict]:
        """Convert CSV-style parameter data into the channel data format."""
        channel_data = {}
        # Mix parameters
        # TODO: deal with this hard-coded 6 later
        for i in range(6):  # assumes at most 6 channels
            ch_key = f'dataset.tuning_parameters.mix_parameters[{i}].ch_n'
            if ch_key in param_data:
                # Parse as float first, then int, so values like "0.0" don't raise
                channel_id = int(float(param_data[ch_key]))
                if channel_id not in channel_data:
                    channel_data[channel_id] = {'filters': []}
                left_key = f'dataset.tuning_parameters.mix_parameters[{i}].mix_left_data'
                if left_key in param_data:
                    channel_data[channel_id]['mix_left_data'] = float(param_data[left_key])
                right_key = f'dataset.tuning_parameters.mix_parameters[{i}].mix_right_data'
                if right_key in param_data:
                    channel_data[channel_id]['mix_right_data'] = float(param_data[right_key])
        # Delay parameters
        for i in range(6):
            ch_key = f'dataset.tuning_parameters.delay_parameters[{i}].ch_n'
            if ch_key in param_data:
                # Float first, then int, as above
                channel_id = int(float(param_data[ch_key]))
                if channel_id not in channel_data:
                    channel_data[channel_id] = {'filters': []}
                delay_key = f'dataset.tuning_parameters.delay_parameters[{i}].delay_data'
                if delay_key in param_data:
                    channel_data[channel_id]['delay_data'] = float(param_data[delay_key])
        # Volume parameters
        for i in range(6):
            ch_key = f'dataset.tuning_parameters.volume_parameters[{i}].ch_n'
            if ch_key in param_data:
                # Float first, then int, as above
                channel_id = int(float(param_data[ch_key]))
                if channel_id not in channel_data:
                    channel_data[channel_id] = {'filters': []}
                vol_key = f'dataset.tuning_parameters.volume_parameters[{i}].vol_data'
                if vol_key in param_data:
                    channel_data[channel_id]['vol_data'] = float(param_data[vol_key])
        # EQ parameters
        for i in range(120):  # at most 120 filters
            fc_key = f'dataset.tuning_parameters.eq_parameters[{i}].fc'
            if fc_key in param_data:
                # Work out which channel this filter belongs to
                channel_id = i // 20  # assumes at most 20 filters per channel
                filter_idx = i % 20
                if channel_id not in channel_data:
                    channel_data[channel_id] = {'filters': []}
                # Make sure the filters list is long enough
                while len(channel_data[channel_id]['filters']) <= filter_idx:
                    channel_data[channel_id]['filters'].append({})
                # Fill in the filter parameters
                filter_data = channel_data[channel_id]['filters'][filter_idx]
                filter_data['fc'] = float(param_data[fc_key])
                q_key = f'dataset.tuning_parameters.eq_parameters[{i}].q'
                if q_key in param_data:
                    filter_data['q'] = float(param_data[q_key])
                gain_key = f'dataset.tuning_parameters.eq_parameters[{i}].gain'
                if gain_key in param_data:
                    filter_data['gain'] = float(param_data[gain_key])
                slope_key = f'dataset.tuning_parameters.eq_parameters[{i}].slope'
                if slope_key in param_data:
                    # Float first, then int
                    filter_data['slope'] = int(float(param_data[slope_key]))
                filter_type_key = f'dataset.tuning_parameters.eq_parameters[{i}].filterType'
                if filter_type_key in param_data:
                    # Float first, then int
                    filter_data['filterType'] = int(float(param_data[filter_type_key]))
                # Optional enable and name parameters
                enable_key = f'dataset.tuning_parameters.eq_parameters[{i}].enable'
                if enable_key in param_data:
                    filter_data['enable'] = param_data[enable_key].lower() == 'true'
                name_key = f'dataset.tuning_parameters.eq_parameters[{i}].name'
                if name_key in param_data:
                    filter_data['name'] = param_data[name_key]
        return channel_data

    def list_projects(self) -> List[str]:
        """List all projects."""
        try:
            projects = []
            for file in os.listdir(self.storage_dir):
                if file.endswith('.json'):
                    projects.append(file[:-5])
            return projects
        except Exception as e:
            logger.error(f"Failed to list projects: {e}")
            return []

    def list_params(self, project_name: str) -> List[str]:
        """List all parameters of a project."""
        try:
            project_data = self.load_project(project_name)
            if project_data:
                return list(project_data.params.keys())
            return []
        except Exception as e:
            logger.error(f"Failed to list parameters: {e}")
            return []

    def delete_project(self, project_name: str) -> bool:
        """Delete a project."""
        try:
            # Delete the project metadata file
            file_path = self._get_project_path(project_name)
            if os.path.exists(file_path):
                os.remove(file_path)
            # Delete the project's parameter files
            params_dir = os.path.join(self.storage_dir, "params")
            for file in os.listdir(params_dir):
                if file.startswith(f"{project_name}_") and file.endswith('.csv'):
                    os.remove(os.path.join(params_dir, file))
            if self.current_project == project_name:
                self.current_project = None
                self.current_param = None
            logger.info(f"Project {project_name} deleted successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to delete project: {e}")
            return False

    def delete_param(self, project_name: str, param_name: str) -> bool:
        """Delete a parameter."""
        try:
            # Update the project metadata
            project_data = self.load_project(project_name)
            if project_data and param_name in project_data.params:
                del project_data.params[param_name]
                project_data.last_modified = datetime.now().isoformat()
                self._save_project_metadata(project_name, project_data)
            # Delete the parameter file
            param_path = self._get_param_path(project_name, param_name)
            if os.path.exists(param_path):
                os.remove(param_path)
            if self.current_project == project_name and self.current_param == param_name:
                self.current_param = None
            logger.info(f"Parameter {param_name} deleted successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to delete parameter: {e}")
            return False

    def _project_exists(self, project_name: str) -> bool:
        """Check whether a project exists."""
        return os.path.exists(self._get_project_path(project_name))

    def _get_project_created_time(self, project_name: str) -> str:
        """Return the project's creation time."""
        if self._project_exists(project_name):
            data = self.load_project(project_name)
            return data.created_at if data else datetime.now().isoformat()
        return datetime.now().isoformat()

    def _save_project_metadata(self, project_name: str, project_data: ProjectData):
        """Save project metadata to file."""
        file_path = self._get_project_path(project_name)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(asdict(project_data), f, indent=2, ensure_ascii=False)
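

# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes persistence.models defines ProjectData/ParamConfig as dataclasses
# and that channel_data follows the shape read and written above; the project
# name and all parameter values below are made up for demonstration.
if __name__ == "__main__":
    store = DataStore(storage_dir="data/projects")
    store.save_project("demo_project", description="example project")
    demo_channel_data = {
        0: {
            "mix_left_data": 1.0,
            "mix_right_data": 0.0,
            "delay_data": 0.5,
            "vol_data": -3.0,
            "filters": [
                {"fc": 1000.0, "q": 0.707, "gain": 3.0, "slope": 0, "filterType": 0},
            ],
        }
    }
    store.add_param_to_project("demo_project", "default", demo_channel_data)
    loaded = store.load_param_data("demo_project", "default")
    logger.info(f"Loaded channels: {sorted(loaded.keys())}")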