@@ -6,9 +6,12 @@ from tqdm import tqdm
 import sqlite3
 from datetime import datetime
 import logging
+from logging.handlers import RotatingFileHandler
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
 import asyncio
 import yaml
+import threading
+from collections import deque
 
 # 配置管理
 class Config:
@@ -16,12 +19,32 @@ class Config:
         self.config_path = config_path
         self.config = self.load_config()
 
+        # 验证配置
+        self.validate_config()
+
         # 设置日志
         self.setup_logging()
 
         # 初始化OpenAI客户端
         self.setup_openai()
 
+    def validate_config(self):
+        """验证配置项"""
+        required_fields = {
+            'logging': ['level', 'format', 'file'],
+            'openai': ['base_url', 'api_key', 'model_name', 'max_retries', 'retry_delay', 'timeout', 'max_concurrent_requests'],
+            'translation': ['min_line_count', 'max_line_count', 'initial_line_count', 'error_threshold', 'success_threshold', 'error_cooldown', 'cache_size'],
+            'database': ['path', 'pool_size'],
+            'paths': ['input_dir', 'output_dir']
+        }
+
+        for section, fields in required_fields.items():
+            if section not in self.config:
+                raise ValueError(f"缺少配置节: {section}")
+            for field in fields:
+                if field not in self.config[section]:
+                    raise ValueError(f"缺少配置项: {section}.{field}")
+
     def load_config(self):
         """加载配置文件"""
         if not os.path.exists(self.config_path):
@@ -72,14 +95,30 @@ class Config:
 
     def setup_logging(self):
         """设置日志"""
-        logging.basicConfig(
-            level=getattr(logging, self.config['logging']['level']),
-            format=self.config['logging']['format'],
-            handlers=[
-                logging.FileHandler(self.config['logging']['file']),
-                logging.StreamHandler()
-            ]
+        log_file = self.config['logging']['file']
+        log_dir = os.path.dirname(log_file)
+        if log_dir and not os.path.exists(log_dir):
+            os.makedirs(log_dir)
+
+        # 创建日志处理器
+        file_handler = RotatingFileHandler(
+            log_file,
+            maxBytes=10*1024*1024,  # 10MB
+            backupCount=5,
+            encoding='utf-8'
         )
+        console_handler = logging.StreamHandler()
+
+        # 设置日志格式
+        formatter = logging.Formatter(self.config['logging']['format'])
+        file_handler.setFormatter(formatter)
+        console_handler.setFormatter(formatter)
+
+        # 配置根日志记录器
+        root_logger = logging.getLogger()
+        root_logger.setLevel(getattr(logging, self.config['logging']['level']))
+        root_logger.addHandler(file_handler)
+        root_logger.addHandler(console_handler)
 
     def setup_openai(self):
         """设置OpenAI客户端"""
@@ -176,6 +215,8 @@ class DatabaseManager:
     def __init__(self):
         self.db_path = config.get('database', 'path')
         self.conn = None
+        self.batch_size = 100  # 批量更新的大小
+        self.pending_updates = []  # 待更新的操作
         self.init_db()
 
     def get_connection(self):
@@ -183,14 +224,91 @@ class DatabaseManager:
         if self.conn is None:
             self.conn = sqlite3.connect(self.db_path)
             self.conn.row_factory = sqlite3.Row
+            # 启用外键约束
+            self.conn.execute("PRAGMA foreign_keys = ON")
+            # 设置WAL模式提高并发性能
+            self.conn.execute("PRAGMA journal_mode = WAL")
         return self.conn
 
     def close(self):
         """关闭数据库连接"""
         if self.conn:
+            # 提交所有待处理的更新
+            self.flush_updates()
             self.conn.close()
             self.conn = None
 
+    def flush_updates(self):
+        """提交所有待处理的更新"""
+        if not self.pending_updates:
+            return
+
+        try:
+            self.begin_transaction()
+            for update in self.pending_updates:
+                update()
+            self.commit_transaction()
+        except Exception as e:
+            self.rollback_transaction()
+            logging.error(f"批量更新失败: {str(e)}")
+            raise
+        finally:
+            self.pending_updates = []
+
+    def add_update(self, update_func):
+        """添加待处理的更新操作"""
+        self.pending_updates.append(update_func)
+        if len(self.pending_updates) >= self.batch_size:
+            self.flush_updates()
+
+    def update_file_progress(self, file_path, total_lines, processed_lines, status):
+        """更新文件翻译进度"""
+        def update():
+            c = self.get_connection().cursor()
+            c.execute('''
+                INSERT OR REPLACE INTO file_progress
+                (file_path, total_lines, processed_lines, status, last_updated)
+                VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
+            ''', (file_path, total_lines, processed_lines, status))
+
+        self.add_update(update)
+
+    def update_line_progress(self, file_path, line_index, original_text, translated_text, status):
+        """更新行翻译进度"""
+        def update():
+            c = self.get_connection().cursor()
+            c.execute('''
+                INSERT OR REPLACE INTO line_progress
+                (file_path, line_index, original_text, translated_text, status, updated_at)
+                VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
+            ''', (file_path, line_index, original_text, translated_text, status))
+
+        self.add_update(update)
+
+    def update_group_progress(self, file_path, group_index, original_text, translated_text, status):
+        """更新翻译组进度"""
+        def update():
+            c = self.get_connection().cursor()
+            c.execute('''
+                INSERT OR REPLACE INTO group_progress
+                (file_path, group_index, original_text, translated_text, status, version, updated_at)
+                VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
+            ''', (file_path, group_index, original_text, translated_text, status, VERSION))
+
+        self.add_update(update)
+
+    def log_error(self, file_path, line_index, error_type, error_message):
+        """记录错误"""
+        def update():
+            c = self.get_connection().cursor()
+            c.execute('''
+                INSERT INTO error_log
+                (file_path, line_index, error_type, error_message)
+                VALUES (?, ?, ?, ?)
+            ''', (file_path, line_index, error_type, error_message))
+
+        self.add_update(update)
+
     def init_db(self):
         """初始化数据库"""
         conn = self.get_connection()
@@ -276,16 +394,6 @@ class DatabaseManager:
         c.execute('SELECT * FROM file_progress WHERE file_path = ?', (file_path,))
         return c.fetchone()
 
-    def update_file_progress(self, file_path, total_lines, processed_lines, status):
-        """更新文件翻译进度"""
-        c = self.get_connection().cursor()
-        c.execute('''
-            INSERT OR REPLACE INTO file_progress
-            (file_path, total_lines, processed_lines, status, last_updated)
-            VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
-        ''', (file_path, total_lines, processed_lines, status))
-        self.get_connection().commit()
-
     def get_line_progress(self, file_path, line_index):
         """获取行翻译进度"""
         c = self.get_connection().cursor()
@@ -295,36 +403,6 @@ class DatabaseManager:
         ''', (file_path, line_index))
         return c.fetchone()
 
-    def update_line_progress(self, file_path, line_index, original_text, translated_text, status):
-        """更新行翻译进度"""
-        c = self.get_connection().cursor()
-        c.execute('''
-            INSERT OR REPLACE INTO line_progress
-            (file_path, line_index, original_text, translated_text, status, updated_at)
-            VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
-        ''', (file_path, line_index, original_text, translated_text, status))
-        self.get_connection().commit()
-
-    def log_error(self, file_path, line_index, error_type, error_message):
-        """记录错误"""
-        c = self.get_connection().cursor()
-        c.execute('''
-            INSERT INTO error_log
-            (file_path, line_index, error_type, error_message)
-            VALUES (?, ?, ?, ?)
-        ''', (file_path, line_index, error_type, error_message))
-        self.get_connection().commit()
-
-    def update_group_progress(self, file_path, group_index, original_text, translated_text, status):
-        """更新翻译组进度"""
-        c = self.get_connection().cursor()
-        c.execute('''
-            INSERT OR REPLACE INTO group_progress
-            (file_path, group_index, original_text, translated_text, status, version, updated_at)
-            VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
-        ''', (file_path, group_index, original_text, translated_text, status, VERSION))
-        self.get_connection().commit()
-
     def get_error_stats(self):
         """获取错误统计信息"""
         c = self.get_connection().cursor()
@@ -371,9 +449,37 @@ def get_completed_groups(conn, file_path):
     ''', (file_path, VERSION))
     return c.fetchall()
 
-# """ - 输出内容要求用代码块包裹起来
-# ,只在必要时提供相应的语言注释
-# """
+class TokenBucket:
+    """令牌桶限流器"""
+    def __init__(self, rate, capacity):
+        self.rate = rate  # 令牌产生速率(每秒)
+        self.capacity = capacity  # 桶容量
+        self.tokens = capacity  # 当前令牌数
+        self.last_update = time.time()
+        self.lock = threading.Lock()
+
+    def get_token(self):
+        """获取一个令牌"""
+        with self.lock:
+            now = time.time()
+            # 计算新增的令牌
+            new_tokens = (now - self.last_update) * self.rate
+            self.tokens = min(self.capacity, self.tokens + new_tokens)
+            self.last_update = now
+
+            if self.tokens >= 1:
+                self.tokens -= 1
+                return True
+            return False
+
+    def wait_for_token(self):
+        """等待直到获得令牌"""
+        while not self.get_token():
+            time.sleep(0.1)
+
+# 创建全局的令牌桶实例
+token_bucket = TokenBucket(rate=2, capacity=10)  # 每秒2个请求,最多10个并发
+
 @retry(
     stop=stop_after_attempt(MODEL_CONFIG['max_retries']),
     wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -387,6 +493,9 @@ def translate_text(text):
         return text
 
     try:
+        # 等待获取令牌
+        token_bucket.wait_for_token()
+
         messages = [
             {
                 "role": "system",
@@ -440,6 +549,22 @@ def translate_text(text):
         translation_stats.update_stats(text, "", False)
         raise
 
+def calculate_group_size(text_length):
+    """根据文本长度动态计算分组大小"""
+    if text_length < 1000:
+        return 4
+    elif text_length < 2000:
+        return 3
+    else:
+        return 2
+
+def resume_translation(file_path, db_manager):
+    """获取断点续传的起始位置"""
+    progress = db_manager.get_file_progress(file_path)
+    if progress and progress['status'] == 'interrupted':
+        return progress['processed_lines']
+    return 0
+
 def process_html_file(file_path, conn):
     """处理HTML文件"""
     # 检查文件进度
@@ -487,12 +612,8 @@ def process_html_file(file_path, conn):
     else:
         logging.info("跳过空标题")
 
-    # 按行分割内容,保留所有HTML标签行,但只翻译包含 <p class 的行
-    lines = []
-    for line in body_content.split('\n'):
-        line = line.strip()
-        if line and line.startswith('<'):
-            lines.append(line)
+    # 按行分割内容,保留所有非空行
+    lines = [line.strip() for line in body_content.split('\n') if line.strip()]
 
     total_lines = len(lines)
     logging.info(f"文件 {file_path} 共有 {total_lines} 行需要处理")
@@ -501,6 +622,9 @@ def process_html_file(file_path, conn):
     completed_lines = get_completed_groups(conn, file_path)
     completed_indices = {line[0] for line in completed_lines}
 
+    # 获取断点续传位置
+    start_line = resume_translation(file_path, db_manager)
+
     # 计算已处理的进度
     if progress:
         progress_percentage = round(progress['processed_lines']*100/progress['total_lines'], 2)
@@ -509,9 +633,11 @@ def process_html_file(file_path, conn):
     # 逐行处理内容
    translated_lines = []
    try:
-        with tqdm(range(0, len(lines), line_count), desc=f"处理文件 {os.path.basename(file_path)}", unit="组") as pbar:
-            for i in range(0, len(lines), line_count):
-                group_index = i // line_count
+        with tqdm(range(start_line, len(lines)), desc=f"处理文件 {os.path.basename(file_path)}", unit="行") as pbar:
+            for i in range(start_line, len(lines)):
+                # 计算当前组的大小
+                current_group_size = calculate_group_size(len(lines[i]))
+                group_index = i // current_group_size
 
                 # 检查是否已完成
                 if group_index in completed_indices:
@@ -520,11 +646,11 @@ def process_html_file(file_path, conn):
                         if line[0] == group_index:
                             translated_lines.extend(line[1].split('\n'))
                             break
-                    pbar.update(1)
+                    pbar.update(current_group_size)
                     continue
 
                 # 获取当前组的行
-                group = lines[i:i+line_count]
+                group = lines[i:i+current_group_size]
                 if group:
                     try:
                         # 收集需要翻译的段落
@@ -539,7 +665,7 @@ def process_html_file(file_path, conn):
                         if paragraphs_to_translate:
                             # 将所有需要翻译的段落合并成一个文本
                             combined_text = "\n".join(paragraphs_to_translate)
-                            logging.info(f"开始翻译第 {i+1}-{min(i+line_count, len(lines))} 行")
+                            logging.info(f"开始翻译第 {i+1}-{min(i+current_group_size, len(lines))} 行")
                             translated_text = translate_text(combined_text)
 
                             # 分割翻译后的文本
@@ -563,7 +689,7 @@ def process_html_file(file_path, conn):
                         translated_lines.extend(translated_group)
 
                         # 更新文件进度
-                        processed_lines = min((group_index + 1) * line_count, total_lines)
+                        processed_lines = min((group_index + 1) * current_group_size, total_lines)
                         db_manager.update_file_progress(file_path, total_lines, processed_lines, 'in_progress')
 
                         # 显示当前统计信息
@@ -579,7 +705,7 @@ def process_html_file(file_path, conn):
                         db_manager.log_error(file_path, group_index, "group_processing_error", str(e))
                         continue
 
-                pbar.update(1)
+                pbar.update(current_group_size)
 
             # 替换原始内容
             if translated_lines: