import os
import subprocess
import cv2
import numpy as np
import re
import time
import random
import sys
import io
from datetime import datetime
from pathlib import Path
from PIL import Image
from difflib import SequenceMatcher
from collections import defaultdict
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor

# 导入ImgOcr类(请确保imgocr.py在当前目录)
from imgocr import ImgOcr

# ===================== Global state (OCR instances are created once and reused) =====================
ocr_det = None  # OCR instance used for text detection; assigned by init_ocr_instance()
ocr_rec = None  # OCR instance used for text recognition; assigned by init_ocr_instance()
group_text_cache = {}  # Caches the best text per group, keyed by (frame names, text type)

# ===================== Meaningful single English letters =====================
# clean_text() keeps a one-character English text only if it is in this set;
# any other one-character English text is discarded.
VALID_SINGLE_ENGLISH_CHARS = {'a', 'i', 'o', 'A', 'I', 'O'}


# ===================== 日志重定向类(保留原有功能) =====================
class Tee:
    """Mirror everything written to stdout/stderr into a log file.

    Creating an instance opens ``filename`` and installs the instance as both
    ``sys.stdout`` and ``sys.stderr``. Every write is forwarded to the log
    file and to the stream that was ``sys.stdout`` at construction time.
    ``close()`` restores the original streams and closes the file.

    The instance can also be used as a context manager, which calls
    ``close()`` on exit.
    """

    def __init__(self, filename, mode='a', encoding='utf-8'):
        self.file = open(filename, mode, encoding=encoding)
        self.stdout = sys.stdout
        self.stderr = sys.stderr
        self._closed = False
        sys.stdout = self
        sys.stderr = self

    def write(self, message):
        """Write ``message`` to the log file and the console; return its length."""
        self.file.write(message)
        self.stdout.write(message)
        self.file.flush()
        self.stdout.flush()
        # File-like objects are expected to report how many characters they
        # consumed; some callers rely on that return value.
        return len(message)

    def flush(self):
        """Flush both the log file and the console stream."""
        self.file.flush()
        self.stdout.flush()

    def close(self):
        """Restore the original streams and close the log file (idempotent)."""
        if self._closed:
            # A second close must not overwrite whatever sys.stdout/sys.stderr
            # have been set to since the first close.
            return
        self._closed = True
        sys.stdout = self.stdout
        sys.stderr = self.stderr
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

# ===================== Configuration =====================
VIDEO_PATH = "/Users/chenhuan/Downloads/中英双字幕版第1季52集/CE003 Best Friend.mp4"  # Input video
FRAME_OUTPUT_DIR = "extracted_frames"  # Where ffmpeg writes the extracted frames
CROPPED_20_DIR = "cropped_20_percent"  # Kept for the legacy file-based crop; the main path crops in memory
FINAL_SUBTITLE_DIR = "final_subtitle_boxes"  # Where cropped subtitle regions are saved
SRT_OUTPUT_PATH = "subtitles.srt"  # Output file written by generate_srt()
LOG_FILE_PATH = f"subtitle_extract_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"  # Timestamped at import time
FPS = 1  # Frame extraction rate (frames per second)
AREA_THRESHOLD = 4000  # Minimum bounding-box area kept by filter_small_boxes()
MIN_SIZE_THRESHOLD = 100  # A box is dropped when BOTH its width and height are below this
SIMILARITY_THRESHOLD = 0.8  # Minimum text similarity for consecutive frames to share a group
CONFIDENCE_THRESHOLD = 0.65  # NOTE(review): not referenced anywhere in the visible code
INVALID_CHARS = r'[\[\]\(\)\\\/\^\$\#\@\!\~\_\*]'  # Characters stripped by clean_text()
LINE_Y_THRESHOLD = 10  # Max center-y distance for OCR fragments to count as the same line
EXTRA_WIDTH = 50  # Horizontal padding added around the merged subtitle box
EXTRA_HEIGHT = 20  # Vertical padding added around the merged subtitle box
SAMPLE_COUNT = 20  # Number of groups sampled to decide the subtitle type (reduced from 35)
BILINGUAL_THRESHOLD = 0.6  # Minimum share of bilingual groups to classify the video as bilingual
SRT_END_OFFSET = 0.1  # Seconds added to each subtitle's end time

# Regex rewrite rules for frequent English OCR mistakes, applied by correct_english_text()
ENGLISH_CORRECTION_RULES = [
    (r"^'m\s", "I'm "),
    (r"^Im\s", "I'm "),
    (r"^m\s", "I'm "),
]
# Extra (pattern, replacement) rules applied by fix_english_spelling(); currently empty
ENGLISH_FIX_RULES = []


# ===================== 核心优化函数:初始化OCR(仅加载一次) =====================
def init_ocr_instance():
    """Build the module-level OCR engines ``ocr_det`` and ``ocr_rec``.

    Both globals are (re)assigned on every call, so this is meant to be
    called once at start-up and the instances reused afterwards.
    """
    global ocr_det, ocr_rec
    print("🔄 初始化OCR模型(仅加载一次)...")
    began_at = time.time()

    shared_options = {"model_version": "v5", "is_efficiency_mode": True}
    # NOTE(review): both engines are created in efficiency mode; the only
    # difference is that the recognition engine enables use_angle_cls.
    ocr_det = ImgOcr(**shared_options)
    ocr_rec = ImgOcr(use_angle_cls=True, **shared_options)

    elapsed = time.time() - began_at
    print(f"✅ OCR模型初始化完成,耗时:{elapsed:.2f}s")


# ===================== 最终修复:FFmpeg提取帧(兼容所有版本) =====================
def extract_frames_with_ffmpeg_optimized(video_path, output_dir, fps=1):
    """Extract frames from ``video_path`` into ``output_dir`` using ffmpeg.

    Frames are written as ``frame_0001.jpg``, ``frame_0002.jpg``, ... at
    ``fps`` frames per second.

    Args:
        video_path: Path of the input video.
        output_dir: Directory for the JPEG frames (created if missing).
        fps: Output frame rate passed to ffmpeg's ``fps`` filter.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited with a non-zero status
            (its stderr is printed before the exception is re-raised).
    """
    os.makedirs(output_dir, exist_ok=True)
    ffmpeg_cmd = [
        "ffmpeg",
        "-i", video_path,  # input video
        "-vsync", "0",  # kept from the original command for cross-version compatibility
        "-filter:v", f"fps={fps}",  # controls the output frame rate
        "-c:v", "mjpeg",  # encode frames as MJPEG
        "-q:v", "3",  # quality (lower value means better quality)
        "-y",  # overwrite existing files without asking
        os.path.join(output_dir, "frame_%04d.jpg"),  # output pattern (must come last)
    ]
    try:
        subprocess.run(
            ffmpeg_cmd,
            check=True,
            capture_output=True,
            text=True,
            encoding='utf-8',
            # ffmpeg may echo metadata/file names that are not valid UTF-8;
            # never let the decoding of its log output abort the extraction.
            errors='replace',
        )
    except subprocess.CalledProcessError as e:
        print(f"❌ 提取帧失败:{e.stderr}")
        raise

    # Count the frames now present in the output directory.
    frame_files = [f for f in os.listdir(output_dir) if f.startswith("frame_") and f.endswith(".jpg")]
    print(f"✅ 成功提取 {len(frame_files)} 帧到 {output_dir}")


# ===================== 优化:内存中裁剪(减少IO) =====================
def crop_bottom_20_percent_in_memory(img):
    """Crop the bottom 20% of ``img`` without touching the disk.

    Returns:
        A ``(cropped_image, box)`` tuple where ``box`` is the
        ``(left, top, right, bottom)`` region that was passed to ``img.crop``.
    """
    img_w, img_h = img.size
    top_edge = int(img_h * 0.8)
    box = (0, top_edge, img_w, img_h)
    bottom_strip = img.crop(box)
    return bottom_strip, box


def crop_bottom_20_percent(image_path, output_dir):
    """Crop the bottom 20% of an image file and save it under ``output_dir``.

    Legacy file-based variant of the in-memory crop. The cropped image is
    saved with the same base name as ``image_path``.

    Returns:
        A ``(cropped_path, crop_region)`` tuple, where ``crop_region`` is the
        ``(left, top, right, bottom)`` box in the source image.
    """
    os.makedirs(output_dir, exist_ok=True)
    cropped_path = os.path.join(output_dir, os.path.basename(image_path))
    # Use a context manager so the source file handle is released promptly
    # instead of lingering until garbage collection.
    with Image.open(image_path) as img:
        width, height = img.size
        crop_region = (0, int(height * 0.8), width, height)
        img.crop(crop_region).save(cropped_path)
    return cropped_path, crop_region


# ===================== 优化:简化预处理(提升速度) =====================
def preprocess_image_optimized(image_path):
    """Load an image and return it as a grayscale PIL image.

    The only preprocessing step is the BGR -> gray conversion.

    Raises:
        OSError: ``cv2.imread`` could not read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with a message that names the offending file.
        raise OSError(f"cv2.imread could not read image: {image_path}")
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return Image.fromarray(gray)


# ===================== 原有工具函数(保留核心逻辑) =====================
def convert_box_to_bounding_rect(box):
    """Return the axis-aligned ``(xmin, ymin, xmax, ymax)`` of a point list.

    ``box`` is a sequence of points indexed as ``[x, y]``; each coordinate
    extreme is truncated to ``int``.
    """
    points = np.array(box)
    xs, ys = points[:, 0], points[:, 1]
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())


def filter_small_boxes(boxes, area_threshold):
    """Drop boxes that are too small to be subtitle text.

    A box is kept when its bounding-rectangle area is at least
    ``area_threshold`` and at least one of its sides reaches
    ``MIN_SIZE_THRESHOLD``.
    """
    kept = []
    for candidate in boxes:
        left, top, right, bottom = convert_box_to_bounding_rect(candidate)
        box_w = right - left
        box_h = bottom - top
        if box_w * box_h < area_threshold:
            continue
        # Reject only when both sides are below the minimum size.
        if box_w >= MIN_SIZE_THRESHOLD or box_h >= MIN_SIZE_THRESHOLD:
            kept.append(candidate)
    return kept


def merge_boxes(boxes):
    """Return the ``(xmin, ymin, xmax, ymax)`` rectangle enclosing all boxes.

    Returns ``None`` when ``boxes`` is empty. Coordinates are truncated to
    ``int``.
    """
    if not boxes:
        return None
    corner_sets = [np.array(one_box) for one_box in boxes]
    xs = np.concatenate([corners[:, 0] for corners in corner_sets])
    ys = np.concatenate([corners[:, 1] for corners in corner_sets])
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())


def crop_and_save_subtitle_box(image_path, crop_region, merged_box, output_dir):
    """Cut the padded subtitle region out of the full frame and save it.

    Args:
        image_path: Path of the full frame image.
        crop_region: ``(left, top, right, bottom)`` of the strip in which
            ``merged_box`` was detected; only its ``top`` value is used, to
            translate the box back into full-frame coordinates.
        merged_box: ``(xmin, ymin, xmax, ymax)`` relative to the strip.
        output_dir: Destination directory (created if missing); the output
            file keeps the base name of ``image_path``.
    """
    os.makedirs(output_dir, exist_ok=True)
    crop_y_start = crop_region[1]
    xmin, ymin, xmax, ymax = merged_box
    subtitle_path = os.path.join(output_dir, os.path.basename(image_path))

    # Use a context manager so the frame's file handle is always released.
    with Image.open(image_path) as img:
        img_width, img_height = img.size

        # Shift the box into full-frame coordinates, pad it, and clamp the
        # result to the image bounds.
        new_xmin = max(0, xmin - EXTRA_WIDTH)
        new_xmax = min(img_width, xmax + EXTRA_WIDTH)
        new_ymin = max(0, ymin + crop_y_start - EXTRA_HEIGHT)
        new_ymax = min(img_height, ymax + crop_y_start + EXTRA_HEIGHT)

        img.crop((new_xmin, new_ymin, new_xmax, new_ymax)).save(subtitle_path)

    print(
        f"📸 已保存字幕区域:{subtitle_path} | 调整后范围:宽[{new_xmin}→{new_xmax}] 高[{new_ymin}→{new_ymax}]")


def has_chinese_char(text):
    """Return True if ``text`` contains a character in U+4E00..U+9FFF."""
    return re.search(r'[\u4e00-\u9fff]', text) is not None


def is_pure_english(text):
    """Return True if ``text`` is made only of ASCII letters, digits,
    whitespace and the punctuation ``, . ; ! ? ' : "`` and has no Chinese
    characters. Surrounding whitespace is ignored for the match.
    """
    trimmed = text.strip()
    if re.match(r'^[a-zA-Z0-9\s,.;!?\':"]+$', trimmed) is None:
        return False
    # Same character range that has_chinese_char() looks for.
    return re.search(r'[\u4e00-\u9fff]', text) is None


# ===================== 修复核心:merge_ocr_fragments_to_lines 函数 =====================
def merge_ocr_fragments_to_lines(ocr_result, y_threshold=None):
    """Merge scattered OCR fragments into subtitle lines.

    Fragments are clustered into lines by the vertical center of their
    boxes, ordered left-to-right inside each line, and joined with spaces.

    Args:
        ocr_result: Iterable of dicts with keys ``'text'``, ``'box'`` (a
            sequence of ``[x, y]`` points) and ``'score'``.
        y_threshold: Maximum distance between a fragment's center-y and the
            center-y of the first fragment of the current line for the
            fragment to join that line. Defaults to ``LINE_Y_THRESHOLD``.

    Returns:
        A list of dicts ``{'text', 'conf', 'y'}`` sorted top-to-bottom, where
        ``conf`` is the mean fragment score and ``y`` the mean fragment
        center-y of the line. Fragments whose text is blank are ignored.
    """
    if not ocr_result:
        return []
    if y_threshold is None:
        y_threshold = LINE_Y_THRESHOLD

    # 1. Collect the non-blank fragments with their position and confidence.
    fragments = []
    for item in ocr_result:
        text = item['text'].strip()
        if not text:
            continue
        box = item['box']
        fragments.append({
            'text': text,
            'conf': item['score'],
            'center_y': np.mean([point[1] for point in box]),
            'start_x': np.min([point[0] for point in box]),
        })
    if not fragments:
        return []

    # 2. Cluster by center-y. Each fragment is compared with the first
    #    fragment of the current line (the line's anchor).
    fragments.sort(key=lambda frag: frag['center_y'])
    lines = [[fragments[0]]]
    anchor_y = fragments[0]['center_y']
    for frag in fragments[1:]:
        if abs(frag['center_y'] - anchor_y) <= y_threshold:
            lines[-1].append(frag)
        else:
            lines.append([frag])
            anchor_y = frag['center_y']

    # 3. Inside each line, order left-to-right and merge.
    merged_lines = []
    for line in lines:
        ordered = sorted(line, key=lambda frag: frag['start_x'])
        merged_lines.append({
            'text': ' '.join(frag['text'] for frag in ordered),
            'conf': np.mean([frag['conf'] for frag in ordered]),
            'y': np.mean([frag['center_y'] for frag in ordered]),
        })

    # 4. Order the lines top-to-bottom.
    return sorted(merged_lines, key=lambda merged: merged['y'])


# ===================== 核心修改:增强文本清洗函数(过滤无含义单个字母) =====================
def clean_text(text, is_chinese=False):
    """Strip invalid characters and normalise spacing around punctuation.

    For English text, a result that is a single character not listed in
    ``VALID_SINGLE_ENGLISH_CHARS`` is discarded (an empty string is returned).
    """
    # Step 1: remove every character matched by INVALID_CHARS.
    text = re.sub(INVALID_CHARS, '', text)

    if is_chinese:
        # Chinese: drop whitespace, then tidy spacing around punctuation.
        chinese_steps = (
            (r'\s+', ''),
            (r'([,。;!?:])\s+', r'\1'),
            (r'\s+([,。;!?:])', r'\1'),
        )
        for pattern, replacement in chinese_steps:
            text = re.sub(pattern, replacement, text)
        return text.strip()

    # Step 2 (English only): filter a lone meaningless character.
    lone = text.strip()
    if len(lone) == 1 and lone not in VALID_SINGLE_ENGLISH_CHARS:
        return ""

    # Step 3 (English): no space before punctuation, exactly one after,
    # collapse whitespace runs, then remove the space that follows a comma.
    text = re.sub(r'\s+([,.;!?\':"])', r'\1', text)
    text = re.sub(r'([,.;!?\':"])\s+', r'\1 ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    text = re.sub(r',\s+', ',', text)
    return text.strip()


def fix_english_spelling(text):
    """Apply every ``ENGLISH_FIX_RULES`` substitution, then collapse whitespace.

    Each rule is a ``(pattern, replacement)`` pair applied in list order.
    """
    result = text
    for rule in ENGLISH_FIX_RULES:
        regex, substitute = rule
        result = re.sub(regex, substitute, result)
    collapsed = re.sub(r'\s+', ' ', result)
    return collapsed.strip()


def correct_english_text(text):
    """Run the spelling fixes, then the ``ENGLISH_CORRECTION_RULES`` rewrites.

    The correction rules are ``(pattern, replacement)`` pairs applied in
    list order to the output of ``fix_english_spelling``.
    """
    result = fix_english_spelling(text)
    for rule in ENGLISH_CORRECTION_RULES:
        regex, substitute = rule
        result = re.sub(regex, substitute, result)
    return result


def standardize_text(text, is_chinese=False):
    """Normalise text for comparison: strip punctuation, collapse whitespace.

    Args:
        text: Text to normalise; a falsy value yields ``""``.
        is_chinese: When False the result is also lower-cased.

    Returns:
        ``text`` without Chinese/English punctuation, with whitespace runs
        collapsed to a single space and surrounding whitespace removed.
    """
    if not text:
        return ""

    # One character class covering both punctuation sets. Concatenating two
    # separate classes would only match a two-character sequence, and an
    # unescaped '-' inside a class forms a range, so the hyphen is escaped.
    punctuation_pattern = (
        r'['
        # Chinese: full-width , 。 ; ! ? : quotes, brackets, 、 … —
        r'\uff0c\u3002\uff1b\uff01\uff1f\uff1a\u201c\u201d\u2018\u2019'
        r'\uff08\uff09\u3010\u3011\u300a\u300b\u3001\u2026\u2014'
        # English / ASCII
        r',.!?;:"\'()\[\]<>\-'
        r']'
    )

    text_no_punc = re.sub(punctuation_pattern, '', text)
    text_clean_space = re.sub(r'\s+', ' ', text_no_punc).strip()

    if is_chinese:
        return text_clean_space
    return text_clean_space.lower()


def calculate_similarity(text1, text2):
    """Return the ``difflib.SequenceMatcher`` ratio of the two texts (0.0-1.0)."""
    matcher = SequenceMatcher(None, text1, text2)
    return matcher.ratio()


def group_similar_frames(ocr_results_list, similarity_threshold=None):
    """Group consecutive frames that show the same subtitle.

    Each result is a tuple whose index 1 holds the Chinese text and index 3
    the English text. A frame joins the current group when, compared with
    the group's last frame, the Chinese texts or the English texts are both
    non-empty and at least ``similarity_threshold`` similar. Two frames with
    no text at all are also grouped together.

    Empty texts are never treated as "similar" on their own: the similarity
    of two empty strings is 1.0, which would otherwise merge every frame of
    a single-language video into one group.

    Args:
        ocr_results_list: Per-frame OCR result tuples, in frame order.
        similarity_threshold: Minimum similarity ratio; defaults to
            ``SIMILARITY_THRESHOLD``.

    Returns:
        A list of groups, each a list of consecutive result tuples.
    """
    if not ocr_results_list:
        return []
    if similarity_threshold is None:
        similarity_threshold = SIMILARITY_THRESHOLD

    def _same_subtitle(prev, curr):
        # Compare (Chinese, English) texts of two consecutive frames.
        prev_texts = (prev[1], prev[3])
        curr_texts = (curr[1], curr[3])
        if not any(prev_texts) and not any(curr_texts):
            return True  # both frames are blank
        for before, after in zip(prev_texts, curr_texts):
            if not before or not after:
                continue  # a missing text carries no evidence either way
            if SequenceMatcher(None, before, after).ratio() >= similarity_threshold:
                return True
        return False

    groups = []
    current_group = [ocr_results_list[0]]
    for result in ocr_results_list[1:]:
        if _same_subtitle(current_group[-1], result):
            current_group.append(result)
        else:
            groups.append(current_group)
            current_group = [result]
    groups.append(current_group)
    return groups


# ===================== 优化:缓存分组最优文本结果 =====================
def get_group_key(group, text_type):
    """Build the cache key for a group: (tuple of frame names, text type)."""
    return tuple(entry[0] for entry in group), text_type


def select_best_text_in_group(group, text_type):
    """Pick the most trustworthy text among the frames of one group.

    Candidates are ranked by how many frames share the same standardized
    text, then by confidence, then by text length; on a full tie the earliest
    frame wins. English candidates are passed through
    ``correct_english_text`` first.

    Args:
        group: List of ``(frame_name, chinese_text, chinese_conf,
            english_text, english_conf)`` tuples.
        text_type: ``'chinese'`` selects the Chinese text; any other value
            selects the English text.

    Returns:
        The winning text, or ``"无"`` when the group has no candidate.
    """
    if not group:
        return "无"

    want_chinese = text_type == 'chinese'
    candidates = []
    for frame_name, zh_text, zh_conf, en_text, en_conf in group:
        if want_chinese:
            if not zh_text:
                continue
            shown, score = zh_text, zh_conf
            normalized = standardize_text(zh_text, is_chinese=True)
        else:
            if not en_text:
                continue
            shown, score = correct_english_text(en_text), en_conf
            normalized = standardize_text(shown, is_chinese=False)
        candidates.append({
            'original': shown,
            'standardized': normalized,
            'conf': score,
            'frame_name': frame_name,
        })

    if not candidates:
        return "无"

    # Count how often each non-empty standardized form occurs.
    votes = defaultdict(int)
    for cand in candidates:
        if cand['standardized']:
            votes[cand['standardized']] += 1

    def _rank(cand):
        return votes.get(cand['standardized'], 0), cand['conf'], len(cand['original'])

    # max() returns the first maximal element, i.e. the earliest frame among
    # candidates with identical rank.
    return max(candidates, key=_rank)['original']


def select_best_text_in_group_cached(group, text_type):
    """Memoised ``select_best_text_in_group``.

    Results are stored in the module-level ``group_text_cache`` under the key
    produced by ``get_group_key``.
    """
    cache_key = get_group_key(group, text_type)
    try:
        return group_text_cache[cache_key]
    except KeyError:
        best = select_best_text_in_group(group, text_type)
        group_text_cache[cache_key] = best
        return best


def sample_groups(groups, sample_count=SAMPLE_COUNT):
    """Return up to ``sample_count`` groups for subtitle-type detection.

    When there are more groups than ``sample_count``, a reproducible random
    sample (fixed seed 42) is drawn; otherwise all groups are returned.
    """
    if len(groups) <= sample_count:
        sampled_groups = groups
        print(f"ℹ️  总组数{len(groups)}≤采样数{sample_count},使用全部组进行字幕类型判断")
    else:
        # Use a private generator so the process-wide random state is not
        # reseeded as a side effect; the sample drawn is the same as with
        # random.seed(42) followed by random.sample().
        sampled_groups = random.Random(42).sample(groups, sample_count)
        print(f"ℹ️  从{len(groups)}组中随机采样{sample_count}组进行字幕类型判断")
    return sampled_groups


# ===================== 优化:缓存版字幕类型判断 =====================
def judge_subtitle_type_optimized(sampled_groups, bilingual_threshold=BILINGUAL_THRESHOLD):
    """Classify the video as bilingual, Chinese-only or English-only.

    Each sampled group is classified by which of its best texts are valid
    (non-empty and not ``"无"``). Groups with no valid text are excluded
    from the ratio.

    Returns:
        ``"bilingual"`` when the share of bilingual groups reaches
        ``bilingual_threshold``; otherwise ``"chinese_only"`` if Chinese-only
        groups outnumber English-only groups, else ``"english_only"``
        (also the fallback when no sampled group has valid text).
    """
    tally = {"bilingual": 0, "chinese_only": 0, "english_only": 0, "empty": 0}

    for grp in sampled_groups:
        zh_best = select_best_text_in_group_cached(grp, 'chinese')
        en_best = select_best_text_in_group_cached(grp, 'english')
        zh_ok = bool(zh_best) and zh_best != "无"
        en_ok = bool(en_best) and en_best != "无"

        if zh_ok and en_ok:
            bucket = "bilingual"
        elif zh_ok:
            bucket = "chinese_only"
        elif en_ok:
            bucket = "english_only"
        else:
            bucket = "empty"
        tally[bucket] += 1

    total_valid = len(sampled_groups) - tally["empty"]
    if total_valid == 0:
        print("❌ 采样组中无有效字幕,默认判定为单语(英文)")
        return "english_only"

    bilingual_ratio = tally["bilingual"] / total_valid
    print(f"\n📊 采样组字幕类型统计:")
    print(f"   - 双语组:{tally['bilingual']}个")
    print(f"   - 仅中文组:{tally['chinese_only']}个")
    print(f"   - 仅英文组:{tally['english_only']}个")
    print(f"   - 空组:{tally['empty']}个")
    print(f"   - 双语占比:{bilingual_ratio:.2%}(阈值:{bilingual_threshold:.2%})")

    if bilingual_ratio >= bilingual_threshold:
        print(f"✅ 判定字幕类型为:双语")
        return "bilingual"
    if tally["chinese_only"] > tally["english_only"]:
        print(f"✅ 判定字幕类型为:仅中文")
        return "chinese_only"
    print(f"✅ 判定字幕类型为:仅英文")
    return "english_only"


# ===================== 优化:缓存版分组过滤 =====================
def filter_groups_by_type_optimized(groups, subtitle_type):
    """Keep only the groups that carry the text required by ``subtitle_type``.

    ``"bilingual"`` requires valid Chinese and English text,
    ``"chinese_only"`` requires valid Chinese text, and any other value
    requires valid English text. A text is valid when it is non-empty and
    not ``"无"``. A summary of the filtering is printed.
    """
    def _wanted(zh_ok, en_ok):
        if subtitle_type == "bilingual":
            return zh_ok and en_ok
        if subtitle_type == "chinese_only":
            return zh_ok
        return en_ok

    kept = []
    for grp in groups:
        zh_best = select_best_text_in_group_cached(grp, 'chinese')
        en_best = select_best_text_in_group_cached(grp, 'english')
        zh_ok = bool(zh_best) and zh_best != "无"
        en_ok = bool(en_best) and en_best != "无"
        if _wanted(zh_ok, en_ok):
            kept.append(grp)

    dropped = len(groups) - len(kept)
    print(f"\n🗂️  分组过滤结果:")
    print(f"   - 原始组数:{len(groups)}")
    print(f"   - 过滤后组数:{len(kept)}")
    print(f"   - 过滤掉的组数:{dropped}")

    return kept


def seconds_to_srt_time(seconds):
    """Format a duration in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond before being split into
    fields, so binary floating-point error (e.g. ``5.1 % 1`` being slightly
    below ``0.1``) cannot shave a millisecond off the result. Negative input
    is clamped to zero.
    """
    total_ms = max(0, int(round(float(seconds) * 1000)))
    total_secs, ms = divmod(total_ms, 1000)
    total_minutes, secs = divmod(total_secs, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"


def generate_srt(groups, subtitle_type, fps=FPS):
    """Write the subtitle groups to ``SRT_OUTPUT_PATH`` in SRT format.

    For each group the start time is derived from its lowest frame number
    and the end time from its highest frame number plus ``SRT_END_OFFSET``.
    Frame numbers are parsed from names of the form ``frame_<n>.jpg``; a
    group with no such name, or with no usable text, is skipped.

    Cues are numbered consecutively in the order they are written, so a
    skipped group never leaves a gap in the SRT numbering.

    Args:
        groups: Groups of per-frame OCR result tuples (frame name first).
        subtitle_type: ``"bilingual"`` writes English then Chinese,
            ``"chinese_only"`` writes Chinese, anything else writes English.
        fps: Frame rate used to convert frame numbers to seconds.
    """
    srt_content = []
    for idx, group in enumerate(groups, 1):
        frame_numbers = []
        for item in group:
            match = re.search(r'frame_(\d+)\.jpg', item[0])
            if match:
                frame_numbers.append(int(match.group(1)))

        if not frame_numbers:
            continue

        first_frame = min(frame_numbers)
        last_frame = max(frame_numbers)
        # Frame numbering starts at 1, so frame n covers [(n-1)/fps, n/fps).
        start_time = (first_frame - 1) / fps
        end_time = (last_frame / fps) + SRT_END_OFFSET

        print(f"📌 第{idx}组时间:首帧{first_frame}({start_time:.2f}s) → 尾帧{last_frame}({end_time:.2f}s)")

        best_chinese = select_best_text_in_group_cached(group, 'chinese')
        best_english = select_best_text_in_group_cached(group, 'english')

        if subtitle_type == "bilingual":
            srt_text_lines = [best_english, best_chinese]
        elif subtitle_type == "chinese_only":
            srt_text_lines = [best_chinese]
        else:
            srt_text_lines = [best_english]

        # Drop empty lines and the "no text" placeholder.
        srt_text_lines = [line for line in srt_text_lines if line and line != "无"]
        if not srt_text_lines:
            continue

        # Number by emitted cues, not by group index.
        cue_number = len(srt_content) + 1
        srt_entry = [
            str(cue_number),
            f"{seconds_to_srt_time(start_time)} --> {seconds_to_srt_time(end_time)}",
            "\n".join(srt_text_lines)
        ]
        srt_content.append("\n".join(srt_entry) + "\n")

    if srt_content:
        with open(SRT_OUTPUT_PATH, 'w', encoding='utf-8') as f:
            f.write("\n".join(srt_content).strip())
        print(f"\n✅ SRT字幕文件已生成:{SRT_OUTPUT_PATH}")
        print(f"📊 最终生成字幕条目数:{len(srt_content)}")
    else:
        print("\n❌ 无有效字幕条目,无法生成SRT文件")


# ===================== 优化:批量OCR识别(简化预处理) =====================
def ocr_subtitle_boxes_optimized(subtitle_dir):
    """Run OCR on every subtitle image in ``subtitle_dir`` and group the results.

    Each ``.jpg``/``.png`` file (in sorted name order) is converted to
    grayscale, recognised with the global ``ocr_rec`` instance, and its
    fragments are merged into lines. Lines containing Chinese characters are
    collected as Chinese text, lines that are pure English as English text;
    other lines are ignored. The per-file result is a tuple
    ``(file_name, chinese_text, chinese_conf, english_text, english_conf)``.

    Returns:
        The list of groups produced by ``group_similar_frames`` (an empty
        list when the directory is missing or holds no images). Progress and
        a per-group summary are printed along the way.
    """
    if not os.path.exists(subtitle_dir):
        print(f"❌ 目录 {subtitle_dir} 不存在,跳过OCR识别")
        return []

    subtitle_files = sorted([f for f in os.listdir(subtitle_dir) if f.endswith((".jpg", ".png"))])
    total_files = len(subtitle_files)
    ocr_results_list = []
    if total_files == 0:
        print(f"⚠️  目录 {subtitle_dir} 下无图片文件,跳过OCR识别")
        return []

    print(f"\n📊 开始识别 {total_files} 张字幕图片(优化版批量处理)...")
    start_time = time.time()

    # Files are walked in slices of batch_size; each image is still passed to
    # the OCR engine individually, and progress is reported once per slice.
    batch_size = 10
    for batch_idx in range(0, total_files, batch_size):
        batch_files = subtitle_files[batch_idx:batch_idx + batch_size]
        batch_results = []

        for sub_file in batch_files:
            sub_path = os.path.join(subtitle_dir, sub_file)
            # Minimal preprocessing: grayscale only.
            preprocessed_img = preprocess_image_optimized(sub_path)
            ocr_result = ocr_rec.ocr(preprocessed_img, det=True, rec=True, cls=False)
            merged_lines = merge_ocr_fragments_to_lines(ocr_result)

            chinese_lines = []
            chinese_confs = []
            english_lines = []
            english_confs = []

            # Route each merged line to the Chinese or English bucket; a line
            # that is neither (e.g. contains characters outside both sets) is dropped.
            for line in merged_lines:
                line_text = line['text'].strip()
                line_conf = line['conf']
                if not line_text:
                    continue

                if has_chinese_char(line_text):
                    chinese_lines.append(line_text)
                    chinese_confs.append(line_conf)
                elif is_pure_english(line_text):
                    english_lines.append(line_text)
                    english_confs.append(line_conf)

            # Clean the joined text (this also filters meaningless single letters);
            # confidence is the mean over the contributing lines, 0.0 if none.
            chinese_text_raw = clean_text(''.join(chinese_lines), is_chinese=True)
            chinese_conf = np.mean(chinese_confs) if chinese_confs else 0.0
            english_text_raw = clean_text(' '.join(english_lines), is_chinese=False)
            english_conf = np.mean(english_confs) if english_confs else 0.0
            batch_results.append((sub_file, chinese_text_raw, chinese_conf, english_text_raw, english_conf))

        ocr_results_list.extend(batch_results)

        # Progress line, rewritten in place via the leading carriage return.
        processed = min(batch_idx + batch_size, total_files)
        progress = (processed / total_files) * 100
        elapsed_time = time.time() - start_time
        print(f"\r🔄 识别进度:{processed}/{total_files} ({progress:.1f}%) | 已用时:{elapsed_time:.1f}s", end="")

    print("\n")
    groups = group_similar_frames(ocr_results_list)
    print(f"✅ 共分为 {len(groups)} 组重复字幕")

    # Print a per-group summary (kept as core debugging output).
    print("\n" + "=" * 150)
    print("最终字幕识别结果(最优结果 + 组内统计)")
    print("=" * 150)

    for i, group in enumerate(groups, 1):
        frame_names = [item[0] for item in group]
        frame_range = f"{frame_names[0]} ~ {frame_names[-1]}" if len(frame_names) > 1 else frame_names[0]

        print(f"\n📦 第{i}组(帧范围:{frame_range} | 共{len(group)}帧)")
        # These calls also populate group_text_cache for the later pipeline steps.
        best_chinese = select_best_text_in_group_cached(group, 'chinese')
        best_english = select_best_text_in_group_cached(group, 'english')

        print(f"   🎯 最优中文文本:{best_chinese if best_chinese else '无'}")
        print(f"   🎯 最优英文文本:{best_english if best_english else '无'}")

    print("\n" + "=" * 150)
    total_time = time.time() - start_time
    print(f"OCR识别优化完成!总耗时:{total_time:.1f}s")
    print("=" * 150)

    return groups


# ===================== 优化:并行处理帧(多线程) =====================
def process_single_frame(args):
    """Detect the subtitle region of one frame and save it (worker function).

    Args:
        args: Tuple ``(frame_file, frame_dir, cropped_dir, subtitle_dir)``.
            ``cropped_dir`` is accepted for signature compatibility but is
            not used, because the bottom strip is cropped in memory.

    Returns:
        ``frame_file`` when a subtitle region was found and saved, otherwise
        ``None`` (no text detected, all boxes filtered out, or an error,
        which is printed rather than raised so one bad frame does not stop
        the whole run).
    """
    frame_file, frame_dir, cropped_dir, subtitle_dir = args
    frame_path = os.path.join(frame_dir, frame_file)
    try:
        # Crop in memory; the context manager releases the file handle.
        with Image.open(frame_path) as img:
            cropped_img, crop_region = crop_bottom_20_percent_in_memory(img)

        # Detection only, using the shared global OCR instance.
        ocr_result = ocr_det.ocr(cropped_img, det=True, rec=False, cls=False)
        if not ocr_result:
            return None

        detected_boxes = [item['box'] for item in ocr_result]
        filtered_boxes = filter_small_boxes(detected_boxes, AREA_THRESHOLD)
        if not filtered_boxes:
            return None

        merged_box = merge_boxes(filtered_boxes)
        if not merged_box:
            return None

        # Save the final subtitle region (the only disk write of this step).
        crop_and_save_subtitle_box(frame_path, crop_region, merged_box, subtitle_dir)
        return frame_file
    except Exception as e:
        print(f"⚠️ 处理帧 {frame_file} 出错:{e}")
        return None


def process_frames_parallel(frame_dir, cropped_dir, subtitle_dir, max_workers=4):
    """Run ``process_single_frame`` over every frame image using a thread pool.

    All ``.jpg``/``.png`` files in ``frame_dir`` are processed in sorted name
    order with ``max_workers`` threads; a summary with the elapsed time and
    the number of frames that yielded a subtitle region is printed.
    """
    frame_files = sorted(
        name for name in os.listdir(frame_dir) if name.endswith((".jpg", ".png"))
    )
    if len(frame_files) == 0:
        print("⚠️ 无帧文件需要处理")
        return

    jobs = [(name, frame_dir, cropped_dir, subtitle_dir) for name in frame_files]
    print(f"🔄 开始并行处理 {len(frame_files)} 帧(线程数:{max_workers})...")
    began_at = time.time()

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        outcomes = list(pool.map(process_single_frame, jobs))

    # A worker returns None when the frame produced no subtitle region.
    succeeded = sum(1 for outcome in outcomes if outcome is not None)
    elapsed = time.time() - began_at
    print(f"✅ 并行处理完成!耗时:{elapsed:.2f}s | 成功处理 {succeeded} 帧")


# ===================== 主函数(整合所有优化) =====================
def main():
    """Run the whole pipeline: frames -> subtitle regions -> OCR -> SRT.

    Output is mirrored to ``LOG_FILE_PATH`` through ``Tee`` for the duration
    of the run. Any exception is reported on stderr and re-raised; the log
    redirection is always undone in the ``finally`` block.
    """
    tee = None
    try:
        # Start mirroring stdout/stderr to the log file.
        tee = Tee(LOG_FILE_PATH, 'w', 'utf-8')
        print(f"📝 日志文件已创建:{LOG_FILE_PATH}")
        print(f"⏰ 程序启动时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("=" * 80 + "\n")

        # 1. Create the OCR instances (done once).
        init_ocr_instance()

        # 2. Extract frames from the video with ffmpeg.
        extract_frames_with_ffmpeg_optimized(VIDEO_PATH, FRAME_OUTPUT_DIR, FPS)

        # 3. Locate and save the subtitle region of each frame (thread pool, in-memory crop).
        process_frames_parallel(FRAME_OUTPUT_DIR, CROPPED_20_DIR, FINAL_SUBTITLE_DIR, max_workers=4)

        # 4. Recognise the saved subtitle regions and group similar frames.
        groups = ocr_subtitle_boxes_optimized(FINAL_SUBTITLE_DIR)
        if not groups:
            print("❌ 无有效字幕分组,无法生成SRT文件")
            return

        # 5. Decide the subtitle type from a sample, then filter all groups by it.
        sampled_groups = sample_groups(groups)
        subtitle_type = judge_subtitle_type_optimized(sampled_groups)
        filtered_groups = filter_groups_by_type_optimized(groups, subtitle_type)

        if not filtered_groups:
            print("❌ 过滤后无有效分组,无法生成SRT文件")
            return

        # 6. Write the SRT file.
        generate_srt(filtered_groups, subtitle_type, FPS)

    except Exception as e:
        print(f"❌ 程序执行出错:{str(e)}", file=sys.stderr)
        raise
    finally:
        # Print the footer while output is still mirrored, then restore the
        # original streams and close the log. Runs on early return as well.
        if tee:
            print("\n" + "=" * 80)
            print(f"⏰ 程序结束时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
            print(f"📝 日志已保存至:{LOG_FILE_PATH}")
            tee.close()


# Run the pipeline only when executed as a script.
if __name__ == "__main__":
    main()
# 最后修改:2026 年 01 月 11 日 08:13 PM
# 如果觉得我的文章对你有用,请随意赞赏