import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import re
import time
class GECTORErrorCorrectorM1:
    """Batch English grammatical-error corrector backed by a seq2seq checkpoint.

    The tokenizer and model are loaded once in ``__init__`` (via
    ``AutoTokenizer`` / ``AutoModelForSeq2SeqLM``) and moved to the first
    available torch device in the order MPS, CUDA, CPU.
    """

    # Patterns used by ``preprocess_text``; compiled once at class creation.
    _TRAILING_NUMBER_RE = re.compile(r"\s+\d+$")
    # A comma glued to the following token.  A comma sitting between two
    # digits (e.g. the thousands separator in "1,000") is deliberately NOT
    # matched so that numbers are not broken apart.
    _TIGHT_COMMA_RE = re.compile(r"(?<!\d),(?=\S)|,(?=[^\s\d])")
    _WHITESPACE_RE = re.compile(r"\s+")

    def __init__(self, model_name="512duncanl/gec-flan-t5-large"):
        """Load the tokenizer and model and place the model on the device.

        Args:
            model_name: Checkpoint identifier handed to ``from_pretrained``.
        """
        self.device = self._select_device()
        print(f"使用设备: {self.device}")
        # The model is loaded without quantisation (the original author notes
        # this avoids errors on MPS).
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
        self.model.eval()  # inference mode

    @staticmethod
    def _select_device():
        """Return the preferred torch device: MPS, then CUDA, then CPU."""
        # ``torch.backends.mps`` is looked up defensively so that a torch
        # build without that attribute falls through instead of raising.
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available() and mps.is_built():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def preprocess_text(self, text):
        """Normalise one input string before it is sent to the model.

        Steps, in order: drop a whitespace-separated number at the very end
        of the string, insert a space after a comma that is glued to the next
        token (digit,digit is left untouched), collapse whitespace runs to a
        single space and strip both ends.

        Args:
            text: Raw input string.

        Returns:
            The cleaned string.
        """
        text = self._TRAILING_NUMBER_RE.sub("", text)
        text = self._TIGHT_COMMA_RE.sub(", ", text)
        return self._WHITESPACE_RE.sub(" ", text).strip()

    def _get_diff(self, original, corrected):
        """Return the word-level ``ndiff`` entries that mark a change.

        Both strings are split on whitespace; only entries starting with
        ``+`` or ``-`` are kept (unchanged words and ``?`` hint lines are
        dropped).
        """
        import difflib

        delta = difflib.ndiff(original.split(), corrected.split())
        return [entry for entry in delta if entry.startswith(("+", "-"))]

    def _generate(self, batch_texts):
        """Tokenise, run beam-search generation and decode one batch."""
        inputs = self.tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=128,
                num_beams=5,
                early_stopping=True,
                # Cache disabled; the original author marks this as an MPS
                # memory optimisation.
                use_cache=False,
            )
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return [t.strip() for t in decoded]

    def batch_correct(self, text_list, batch_size=4):
        """Correct texts in batches and report an averaged per-text latency.

        Args:
            text_list: Iterable of input strings.  A single string is treated
                as a one-element list rather than iterated per character.
            batch_size: Number of texts per model call; must be at least 1.

        Returns:
            A list with one dict per input text (keys ``序号``, ``原始文本``,
            ``预处理后``, ``修复后文本``, ``错误详情``, ``单条推理耗时(ms)``)
            followed by one summary dict whose ``序号`` is ``"总计"``.  The
            per-text latency is the batch wall time divided evenly among the
            texts of that batch.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        # Materialise the input so that len() and slicing work for any
        # iterable, and guard against a bare string being split into chars.
        if isinstance(text_list, str):
            text_list = [text_list]
        else:
            text_list = list(text_list)

        # perf_counter is monotonic, so intervals are immune to clock changes.
        total_start_time = time.perf_counter()
        results = []
        processed_texts = [self.preprocess_text(t) for t in text_list]

        for start in range(0, len(processed_texts), batch_size):
            batch_start = time.perf_counter()
            batch_texts = processed_texts[start:start + batch_size]
            batch_origs = text_list[start:start + batch_size]

            corrected_batch = self._generate(batch_texts)

            # Batch wall time in milliseconds, split evenly across its texts.
            batch_total_time = (time.perf_counter() - batch_start) * 1000
            single_text_time = round(batch_total_time / len(batch_texts), 2)

            rows = zip(batch_origs, batch_texts, corrected_batch)
            for idx, (orig, processed, corrected) in enumerate(rows, start=start + 1):
                errors = []
                if processed != corrected:
                    errors.append({
                        "original": processed,
                        "corrected": corrected,
                        "diff": self._get_diff(processed, corrected),
                    })
                results.append({
                    "序号": idx,
                    "原始文本": orig,
                    "预处理后": processed,
                    "修复后文本": corrected,
                    "错误详情": errors if errors else "无错误",
                    "单条推理耗时(ms)": single_text_time,
                })

        # Trailing summary row with the total elapsed time.
        total_time = round((time.perf_counter() - total_start_time) * 1000, 2)
        results.append({
            "序号": "总计",
            "原始文本": f"共{len(text_list)}条文本 | 批量大小{batch_size}",
            "预处理后": "-",
            "修复后文本": "-",
            "错误详情": "-",
            "单条推理耗时(ms)": f"总耗时 {total_time} ms",
        })
        return results
# ---------------------- Demo: device acceleration + per-text timing ----------------------
if __name__ == "__main__":
    # Sample sentences fed to the corrector.
    samples = [
        "I have 5 apples",
        "And turn you into a frog. 4",
        "Peppa and Suzy love playing in Peppa'a bedroom.",
        "No,George. This game is just for big girls",
        "I'm a tiny,little fair princess.",
        "George,I need some help:",
        "I'm making chocolate chip:cookies.",
        "Hmmm,I think your heart's a bit loose.",
        "Peppa and Suzy loves playing doctors and nurses.",
        "He r might get tired.",
        "Then I think I need Iots of cookies to make me better.",
        "Im a tiny,little fair princess.",
    ]

    # Build the corrector (loads the model) and run everything in batches of 4.
    engine = GECTORErrorCorrectorM1()
    report = engine.batch_correct(samples, batch_size=4)

    # Print the report: one section per text, then the summary row.
    banner = "=" * 100
    print(banner)
    print("GECToR M1 优化版 英文文本纠错结果(MPS加速+单条耗时统计)")
    print(banner)
    for res in report:
        print(f"\n【序号】{res['序号']}")
        print(f"原始文本:{res['原始文本']}")
        if res['序号'] == "总计":
            # The summary row only carries the total elapsed time.
            print(f"{res['单条推理耗时(ms)']}")
            continue
        print(f"预处理后:{res['预处理后']}")
        print(f"修复后:{res['修复后文本']}")
        print(f"错误详情:{res['错误详情']}")
        print(f"单条推理耗时:{res['单条推理耗时(ms)']} 毫秒")
# Last modified: 2026-01-11 06:27 PM
# © Reposting permitted with proper attribution

