import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import re
import time


class GECTORErrorCorrectorM1:
    """English grammatical-error corrector tuned for Apple-Silicon (M1) Macs.

    Despite the "GECToR" name, this class loads a sequence-to-sequence
    checkpoint (``AutoModelForSeq2SeqLM``; the default model id suggests a
    FLAN-T5 model) and produces each corrected sentence with beam search via
    ``model.generate``.

    Device preference when none is given: MPS > CUDA > CPU.
    """

    # A whitespace-separated number at the very end of a line, e.g. the "4" in
    # "frog. 4" (presumably line-index residue in the input data — confirm).
    _TRAILING_NUMBER_RE = re.compile(r"\s+\d+$")
    # A comma immediately followed by a non-space character. A comma sitting
    # between two digits ("1,000") is deliberately NOT matched, so grouped
    # numbers are not split into "1, 000".
    _MISSING_SPACE_AFTER_COMMA_RE = re.compile(r"(?<!\d),(?=\S)|,(?=[^\s\d])")
    _WHITESPACE_RUN_RE = re.compile(r"\s+")

    def __init__(self, model_name="512duncanl/gec-flan-t5-large", device=None):
        """Select a device, then load the tokenizer and model in eval mode.

        Args:
            model_name: Hugging Face model id or local path passed to
                ``from_pretrained``.
            device: Optional override (``"cpu"``, ``"cuda"``, ``"mps"`` or a
                ``torch.device``). ``None`` auto-selects MPS > CUDA > CPU.
        """
        self.device = (
            torch.device(device) if device is not None else self._select_device()
        )
        print(f"使用设备: {self.device}")

        # Loaded without quantization, to avoid errors on MPS.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
        self.model.eval()  # inference mode (disables dropout)

    @staticmethod
    def _select_device():
        """Return the best available torch device: MPS, then CUDA, then CPU."""
        if torch.backends.mps.is_available() and torch.backends.mps.is_built():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def preprocess_text(self, text: str) -> str:
        """Normalize one raw line before it is sent to the model.

        Steps, in order:
        1. Trim surrounding whitespace, so a trailing number followed by stray
           spaces or a newline is still recognized.
        2. Drop a whitespace-separated number at the very end of the line
           (``"frog. 4"`` -> ``"frog."``). Note: this heuristic also drops a
           meaningful final number such as the "5" in ``"I have 5"``.
        3. Insert a space after commas that lack one, leaving digit-grouping
           commas such as ``"1,000"`` intact.
        4. Collapse whitespace runs to a single space and trim again.

        Args:
            text: Raw input sentence.

        Returns:
            The normalized sentence.
        """
        text = text.strip()
        text = self._TRAILING_NUMBER_RE.sub("", text)
        text = self._MISSING_SPACE_AFTER_COMMA_RE.sub(", ", text)
        return self._WHITESPACE_RUN_RE.sub(" ", text).strip()

    def _get_diff(self, original, corrected):
        """Return the word-level edits between two sentences.

        Runs ``difflib.ndiff`` on whitespace-split tokens and keeps only the
        ``"- word"`` (removed) and ``"+ word"`` (added) entries; unchanged
        tokens and ``"? "`` hint lines are dropped.
        """
        import difflib

        diff = difflib.ndiff(original.split(), corrected.split())
        return [d for d in diff if d.startswith(("+", "-"))]

    def _generate(self, batch_texts, max_length=128, num_beams=5):
        """Tokenize one batch, run beam-search generation, and decode it.

        Inputs longer than ``max_length`` tokens are silently truncated.
        Returns the decoded sentences, stripped, in input order.
        """
        inputs = self.tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                # Disables the decoder key/value cache to lower memory use on
                # MPS; the trade-off is slower generation.
                use_cache=False,
            )

        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return [t.strip() for t in decoded]

    def _build_record(self, idx, orig, processed, corrected, per_text_ms):
        """Assemble the result dict for one sentence (see ``batch_correct``)."""
        errors = []
        if processed != corrected:
            errors.append({
                "original": processed,
                "corrected": corrected,
                "diff": self._get_diff(processed, corrected),
            })
        return {
            "序号": idx,
            "原始文本": orig,
            "预处理后": processed,
            "修复后文本": corrected,
            "错误详情": errors if errors else "无错误",
            "单条推理耗时(ms)": per_text_ms,
        }

    def batch_correct(self, text_list, batch_size=4, *, max_length=128, num_beams=5):
        """Correct sentences in batches and report per-sentence latency.

        Args:
            text_list: Iterable of raw sentences. It is materialized into a
                list once, so generators are accepted.
            batch_size: Sentences per ``generate`` call; must be >= 1.
            max_length: Token limit for both input truncation and generated
                output.
            num_beams: Beam width for beam search.

        Returns:
            One dict per input sentence, in input order, with keys ``序号``
            (1-based index), ``原始文本``, ``预处理后``, ``修复后文本``,
            ``错误详情`` (a one-element list holding the word diff, or
            ``"无错误"``) and ``单条推理耗时(ms)``. The latency is the batch's
            elapsed time divided evenly across its sentences, not a true
            per-sentence measurement. The list ends with one summary dict whose
            ``序号`` is ``"总计"`` and whose ``单条推理耗时(ms)`` holds the total
            elapsed time, including preprocessing.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        texts = list(text_list)
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # perf_counter is monotonic and high-resolution, unlike time.time().
        total_start = time.perf_counter()
        processed_texts = [self.preprocess_text(t) for t in texts]
        results = []

        for start in range(0, len(processed_texts), batch_size):
            batch_start = time.perf_counter()
            batch_texts = processed_texts[start:start + batch_size]
            batch_origs = texts[start:start + batch_size]

            corrected_batch = self._generate(batch_texts, max_length, num_beams)

            elapsed_ms = (time.perf_counter() - batch_start) * 1000
            per_text_ms = round(elapsed_ms / len(batch_texts), 2)

            for offset, (orig, processed, corrected) in enumerate(
                zip(batch_origs, batch_texts, corrected_batch)
            ):
                results.append(
                    self._build_record(start + offset + 1, orig, processed, corrected, per_text_ms)
                )

        total_time = round((time.perf_counter() - total_start) * 1000, 2)
        results.append({
            "序号": "总计",
            "原始文本": f"共{len(texts)}条文本 | 批量大小{batch_size}",
            "预处理后": "-",
            "修复后文本": "-",
            "错误详情": "-",
            "单条推理耗时(ms)": f"总耗时 {total_time} ms",
        })
        return results


# ---------------------- Demo: M1 acceleration + per-text latency ----------------------
if __name__ == "__main__":
    # Device selection and model loading both happen in the constructor.
    corrector = GECTORErrorCorrectorM1()

    # Sample sentences; several appear to contain deliberate errors for the
    # model to fix (e.g. "Peppa'a", "Iots", missing spaces after commas, a
    # stray trailing "4").
    test_texts = [
        "I have 5 apples",
        "And turn you into a frog. 4",
        "Peppa and Suzy love playing in Peppa'a bedroom.",
        "No,George. This game is just for big girls",
        "I'm a tiny,little fair princess.",
        "George,I need some help:",
        "I'm making chocolate chip:cookies.",
        "Hmmm,I think your heart's a bit loose.",
        "Peppa and Suzy loves playing doctors and nurses.",
        "He r might get tired.",
        "Then I think I need Iots of cookies to make me better.",
        "Im a tiny,little fair princess."
    ]

    # Correct everything in batches of four.
    correction_results = corrector.batch_correct(test_texts, batch_size=4)

    rule = "=" * 100
    print(rule)
    print("GECToR M1 优化版 英文文本纠错结果(MPS加速+单条耗时统计)")
    print(rule)

    for record in correction_results:
        label = record["序号"]
        print(f"\n【序号】{label}")
        print(f"原始文本:{record['原始文本']}")
        if label == "总计":
            # The trailing summary row only carries the overall elapsed time.
            print(f"{record['单条推理耗时(ms)']}")
            continue
        print(f"预处理后:{record['预处理后']}")
        print(f"修复后:{record['修复后文本']}")
        print(f"错误详情:{record['错误详情']}")
        print(f"单条推理耗时:{record['单条推理耗时(ms)']} 毫秒")
# 最后修改:2026 年 01 月 11 日 06:27 PM
# 如果觉得我的文章对你有用,请随意赞赏