// Usage example: initialize the models, run OCR on a bundled asset image and
// report how long the recognition call took.
  await PPOCRv5.init();
  print("✅ 模型初始化成功");

  // Wall-clock timestamp taken right before recognition starts.
  final startTime = DateTime.now();
  print("⏱️ 开始识别:${startTime.toLocal()}");

  // Recognize an image shipped in the app's assets.
  final result = await PPOCRv5.fromAsset('assets/images/test4.png');
  print("识别结果:$result");

  // Stop timing and print the elapsed time together with the result.
  final endTime = DateTime.now();
  final cost = endTime.difference(startTime);
  print("🏁 结束识别:${endTime.toLocal()}");
  print("⏱️ 耗时:${cost.inMilliseconds}ms / ${cost.inSeconds}s");
  print("结果:$result");
import 'dart:math' as math;
import 'dart:typed_data';

import 'package:flutter/services.dart' show rootBundle;
import 'package:flutter_onnxruntime/flutter_onnxruntime.dart';
import 'package:image/image.dart' as img;
import 'package:opencv_dart/opencv.dart' as cv;

/// Tunable settings for the PP-OCRv5 pipeline.
///
/// Holds the model asset paths plus the parameters consumed by the detection,
/// recognition and (optional) angle-classification stages. Every field is
/// mutable so callers can adjust the defaults before the models are loaded.
class PPOCRArgs {
  /// Asset path of the text-detection ONNX model.
  String detModel = 'assets/models/det_infer.onnx';

  /// Asset path of the text-recognition ONNX model.
  String recModel = 'assets/models/rec_infer.onnx';

  /// Asset path of the recognition dictionary (one character per line).
  String dictPath = 'assets/models/ppocrv5_dict.txt';

  /// Asset path of the angle-classification model.
  ///
  /// Angle classification is disabled while this is `null` or empty; see
  /// [useAngleCls].
  String? clsModel;

  /// Side-length limit applied when resizing the detection input.
  int detLimitSideLen = 960;

  /// Which side [detLimitSideLen] constrains: `'max'` selects the longer
  /// side, any other value selects the shorter side.
  String detLimitType = 'max';

  /// Probability above which a pixel of the detection map counts as text.
  double detDbThresh = 0.3;

  /// Minimum mean probability a detected region needs to be kept.
  double detDbBoxThresh = 0.6;

  /// Expansion factor applied to each detected region.
  double detDbUnclipRatio = 1.5;

  /// Whether the binarised detection map is dilated before contour search.
  bool useDilation = false;

  /// Region scoring mode: `'fast'` scores the min-area box, any other value
  /// scores the raw contour.
  String detDbScoreMode = 'fast';

  /// Output geometry: `'poly'` yields polygons, any other value yields
  /// 4-point boxes.
  String detBoxType = 'quad';

  /// Recognition input shape as `[channels, height, minimumWidth]`.
  List<int> recImageShape = [3, 48, 320];

  /// Number of crops sent to the recognition model per inference call.
  int recBatchNum = 6;

  /// Whether a space character is appended to the recognition dictionary.
  bool useSpaceChar = true;

  /// Recognition results scoring below this value are dropped.
  double dropScore = 0.5;

  /// Angle-classifier input shape as `[channels, height, width]`.
  List<int> clsImageShape = [3, 48, 192];

  /// Labels of the angle classifier, indexed by output class.
  List<String> labelList = ['0', '180'];

  /// Number of crops sent to the angle classifier per inference call.
  int clsBatchNum = 6;

  /// Score a `'180'` prediction must exceed before the crop is rotated.
  double clsThresh = 0.9;

  /// Whether an angle-classification model has been configured.
  bool get useAngleCls => clsModel?.isNotEmpty ?? false;
}

/// Output of detection preprocessing: the tensor payload plus the bookkeeping
/// needed to map predictions back onto the source image.
class _DetPreprocessResult {
  /// Normalised pixel data, laid out channel by channel.
  final Float32List input;

  /// Tensor dimensions describing [input].
  final List<int> inputShape;

  /// Resize bookkeeping: `[srcHeight, srcWidth, heightRatio, widthRatio]`.
  final List<double> shape;

  /// All fields are final, so the constructor can be `const`.
  const _DetPreprocessResult({
    required this.input,
    required this.inputShape,
    required this.shape,
  });
}

/// One decoded recognition result: the text and its confidence score.
class _RecDecodeResult {
  /// Creates a result from the decoded [text] and its [score].
  const _RecDecodeResult(this.text, this.score);

  /// Decoded characters joined into a single string.
  final String text;

  /// Confidence assigned to [text].
  final double score;
}

/// One angle-classification result: the winning label and its score.
class _ClsDecodeResult {
  /// Creates a result from the predicted [label] and its [score].
  const _ClsDecodeResult(this.label, this.score);

  /// Label of the highest-scoring class.
  final String label;

  /// Score of the highest-scoring class.
  final double score;
}


/// Conversions between `package:image` images and OpenCV matrices.
///
/// Both directions keep the channels in R, G, B order: [imageToMat] requests
/// RGB bytes and [matToImage] reads channel 0 back as red.
class _ImageConverters {
  /// Copies [image] into a new 3-channel, 8-bit matrix.
  static cv.Mat imageToMat(img.Image image) {
    // The matrix is declared as CV_8UC3, so the byte buffer must contain
    // exactly three 8-bit samples per pixel. Normalise every other layout
    // (palette, alpha, grayscale, 16-bit/float samples) before reading bytes.
    final needsConversion = image.hasPalette ||
        image.numChannels != 3 ||
        image.format != img.Format.uint8;
    final rgb = needsConversion
        ? image.convert(format: img.Format.uint8, numChannels: 3)
        : image;
    final bytes = rgb.getBytes(order: img.ChannelOrder.rgb);
    return cv.Mat.fromList(
      rgb.height,
      rgb.width,
      cv.MatType.CV_8UC3,
      bytes,
    );
  }

  /// Copies a 3-channel [mat] into a new image, pixel by pixel.
  static img.Image matToImage(cv.Mat mat) {
    final data = mat.toList3D();
    final image = img.Image(width: mat.cols, height: mat.rows);
    for (int y = 0; y < mat.rows; y++) {
      final row = data[y];
      for (int x = 0; x < mat.cols; x++) {
        final pixel = row[x];
        image.setPixelRgb(
          x,
          y,
          pixel[0].toInt(),
          pixel[1].toInt(),
          pixel[2].toInt(),
        );
      }
    }
    return image;
  }
}

/// Resizes an image for the detection model so that one side respects
/// [limitSideLen] and both sides are multiples of 32.
class _DetResizeForTest {
  /// Side-length limit, in pixels.
  final int limitSideLen;

  /// `'max'` caps the longer side at [limitSideLen]; any other value makes
  /// the shorter side at least [limitSideLen].
  final String limitType;

  const _DetResizeForTest({required this.limitSideLen, required this.limitType});

  /// Returns the resized image together with
  /// `[srcHeight, srcWidth, heightRatio, widthRatio]`.
  ({img.Image image, List<double> shape}) call(img.Image image) {
    final srcH = image.height;
    final srcW = image.width;

    final double ratio;
    if (limitType == 'max') {
      // Shrink only when the longer side is over the limit.
      final longSide = math.max(srcH, srcW);
      ratio = longSide > limitSideLen ? limitSideLen / longSide : 1.0;
    } else {
      // "min" mode enlarges images whose shorter side is below the limit.
      // The previous `>` comparison shrank large images instead, which is the
      // opposite of the reference DetResizeForTest behaviour.
      final shortSide = math.min(srcH, srcW);
      ratio = shortSide > 0 && shortSide < limitSideLen
          ? limitSideLen / shortSide
          : 1.0;
    }

    int resizeH = (srcH * ratio).toInt();
    int resizeW = (srcW * ratio).toInt();

    // Snap both sides to the nearest multiple of 32, never below 32.
    resizeH = math.max(((resizeH / 32).round()) * 32, 32);
    resizeW = math.max(((resizeW / 32).round()) * 32, 32);

    final resized = img.copyResize(image, width: resizeW, height: resizeH);
    return (
      image: resized,
      shape: [
        srcH.toDouble(),
        srcW.toDouble(),
        resizeH / srcH,
        resizeW / srcW,
      ],
    );
  }
}

/// Turns a detection probability map into text regions expressed in
/// source-image coordinates.
///
/// The map is binarised, contours are extracted, each candidate is scored by
/// the mean probability it covers, expanded ("unclipped") and finally scaled
/// from map size to source size.
class _DBPostProcess {
  /// Probability above which a map pixel counts as text.
  final double thresh;

  /// Minimum mean probability a candidate needs to be kept.
  final double boxThresh;

  /// Upper bound on the number of contours examined.
  final int maxCandidates;

  /// Expansion factor used by [_unclip].
  final double unclipRatio;

  /// Whether the binary map is dilated with a 2x2 kernel before contouring.
  final bool useDilation;

  /// `'fast'` scores quad candidates on their min-area box; any other value
  /// scores them on the raw contour.
  final String scoreMode;

  /// `'poly'` yields polygons; any other value yields 4-point boxes.
  final String boxType;

  /// Shortest accepted side of a candidate's min-area rectangle.
  final int minSize = 3;

  const _DBPostProcess({
    required this.thresh,
    required this.boxThresh,
    required this.maxCandidates,
    required this.unclipRatio,
    required this.useDilation,
    required this.scoreMode,
    required this.boxType,
  });

  /// Extracts regions from the row-major probability map [pred] of size
  /// [predH] x [predW].
  ///
  /// [shape] carries the source height and width in its first two entries.
  /// Throws an [ArgumentError] when [pred] does not hold exactly
  /// `predH * predW` values.
  List<List<List<double>>> call({
    required List<double> pred,
    required int predH,
    required int predW,
    required List<double> shape,
  }) {
    // Fail with a clear message instead of an index error further down.
    final pixelCount = predH * predW;
    if (pred.length != pixelCount) {
      throw ArgumentError.value(
        pred.length,
        'pred.length',
        'must equal predH * predW ($pixelCount)',
      );
    }

    final srcH = shape[0].toInt();
    final srcW = shape[1].toInt();

    final bitmapData = Uint8List(pixelCount);
    for (int i = 0; i < pixelCount; i++) {
      bitmapData[i] = pred[i] > thresh ? 255 : 0;
    }

    var bitmap = cv.Mat.fromList(predH, predW, cv.MatType.CV_8UC1, bitmapData);
    if (useDilation) {
      final kernel = cv.Mat.fromList(2, 2, cv.MatType.CV_8UC1, [1, 1, 1, 1]);
      bitmap = cv.dilate(bitmap, kernel);
    }

    if (boxType == 'poly') {
      return _polygonsFromBitmap(pred, bitmap, predH, predW, srcW, srcH);
    }
    return _boxesFromBitmap(pred, bitmap, predH, predW, srcW, srcH);
  }

  /// Returns one 4-point box per accepted contour of [bitmap].
  List<List<List<double>>> _boxesFromBitmap(
    List<double> pred,
    cv.Mat bitmap,
    int predH,
    int predW,
    int destW,
    int destH,
  ) {
    final (contours, _) = cv.findContours(bitmap, cv.RETR_LIST, cv.CHAIN_APPROX_SIMPLE);
    final numContours = math.min(contours.length, maxCandidates);
    final boxes = <List<List<double>>>[];

    for (int i = 0; i < numContours; i++) {
      final contour = contours[i];
      final mini = _getMiniBoxes(contour);
      if (mini == null || mini.$2 < minSize) {
        continue;
      }

      final points = mini.$1;
      final scored = scoreMode == 'fast' ? points : _vecPointToDoubleList(contour);
      if (_regionScore(pred, predH, predW, scored) < boxThresh) {
        continue;
      }

      final expanded = _unclipAndValidate(points);
      if (expanded == null) {
        continue;
      }
      // Quad mode reports the min-area box of the expanded outline.
      boxes.add(_scaleToDest(expanded.box, predW, predH, destW, destH));
    }

    return boxes;
  }

  /// Returns one polygon per accepted contour of [bitmap].
  List<List<List<double>>> _polygonsFromBitmap(
    List<double> pred,
    cv.Mat bitmap,
    int predH,
    int predW,
    int destW,
    int destH,
  ) {
    final (contours, _) = cv.findContours(bitmap, cv.RETR_LIST, cv.CHAIN_APPROX_SIMPLE);
    final boxes = <List<List<double>>>[];
    final limit = math.min(contours.length, maxCandidates);

    for (int i = 0; i < limit; i++) {
      final contour = contours[i];
      final epsilon = 0.002 * cv.arcLength(contour, true);
      final approx = cv.approxPolyDP(contour, epsilon, true);
      if (approx.length < 4) {
        continue;
      }

      final points = approx.map((e) => [e.x.toDouble(), e.y.toDouble()]).toList();
      if (_regionScore(pred, predH, predW, points) < boxThresh) {
        continue;
      }

      final expanded = _unclipAndValidate(points);
      if (expanded == null) {
        continue;
      }
      // Polygon mode reports the expanded outline itself.
      boxes.add(_scaleToDest(expanded.outline, predW, predH, destW, destH));
    }

    return boxes;
  }

  /// Expands [points] and checks the result is still a usable region.
  ///
  /// Returns the expanded outline with its min-area box, or `null` when the
  /// expansion does not yield exactly one path or the min-area rectangle's
  /// short side is below `minSize + 2`.
  ({List<List<double>> outline, List<List<double>> box})? _unclipAndValidate(
    List<List<double>> points,
  ) {
    final paths = _unclip(points, unclipRatio);
    if (paths.length != 1) {
      return null;
    }
    final outline = paths.first;
    final outlineVec = cv.VecPoint.fromList([
      for (final p in outline) cv.Point(p[0].round(), p[1].round()),
    ]);
    final mini = _getMiniBoxes(outlineVec);
    if (mini == null || mini.$2 < minSize + 2) {
      return null;
    }
    return (outline: outline, box: mini.$1);
  }

  /// Scales [points] from map size to destination size, rounding each
  /// coordinate and clipping it to `[0, destW]` / `[0, destH]`.
  List<List<double>> _scaleToDest(
    List<List<double>> points,
    int predW,
    int predH,
    int destW,
    int destH,
  ) {
    return [
      for (final p in points)
        [
          _clipDouble((p[0] / predW * destW).roundToDouble(), 0, destW.toDouble()),
          _clipDouble((p[1] / predH * destH).roundToDouble(), 0, destH.toDouble()),
        ],
    ];
  }

  /// Returns the min-area rectangle of [contour] as
  /// `[topLeft, topRight, bottomRight, bottomLeft]` plus its short side, or
  /// `null` when the contour has fewer than four points.
  (List<List<double>>, double)? _getMiniBoxes(cv.VecPoint contour) {
    if (contour.length < 4) {
      return null;
    }
    final rotatedRect = cv.minAreaRect(contour);
    // Sorted by x: entries 0-1 are the left pair, entries 2-3 the right pair.
    final points = cv.boxPoints(rotatedRect)
        .map((e) => [e.x, e.y])
        .toList()
      ..sort((a, b) => a[0].compareTo(b[0]));

    // Within each pair the smaller y is the top corner (y grows downwards).
    final leftInOrder = points[1][1] > points[0][1];
    final topLeft = leftInOrder ? points[0] : points[1];
    final bottomLeft = leftInOrder ? points[1] : points[0];

    final rightInOrder = points[3][1] > points[2][1];
    final topRight = rightInOrder ? points[2] : points[3];
    final bottomRight = rightInOrder ? points[3] : points[2];

    final shortSide = math.min(rotatedRect.size.width, rotatedRect.size.height);
    return ([topLeft, topRight, bottomRight, bottomLeft], shortSide);
  }

  /// Mean of the [h] x [w] row-major map [bitmap] over the pixels covered by
  /// [polygon].
  ///
  /// The mean is computed on the polygon's bounding box (clamped to the map)
  /// through a filled mask, so only pixels inside the polygon contribute.
  /// Used for both scoring modes; they differ only in the polygon passed in.
  double _regionScore(
    List<double> bitmap,
    int h,
    int w,
    List<List<double>> polygon,
  ) {
    int xmin = polygon.map((p) => p[0].floor()).reduce(math.min).clamp(0, w - 1);
    int xmax = polygon.map((p) => p[0].ceil()).reduce(math.max).clamp(0, w - 1);
    int ymin = polygon.map((p) => p[1].floor()).reduce(math.min).clamp(0, h - 1);
    int ymax = polygon.map((p) => p[1].ceil()).reduce(math.max).clamp(0, h - 1);

    final width = xmax - xmin + 1;
    final height = ymax - ymin + 1;
    if (width <= 0 || height <= 0) {
      return 0;
    }

    // Rasterise the polygon, shifted into bounding-box coordinates.
    final mask = cv.Mat.fromScalar(height, width, cv.MatType.CV_8UC1, cv.Scalar());
    final shifted = polygon
        .map((p) => cv.Point((p[0] - xmin).round(), (p[1] - ymin).round()))
        .toList();
    cv.fillPoly(mask, cv.VecVecPoint.fromList([shifted]), cv.Scalar(1));

    final cropData = <double>[];
    for (int y = ymin; y <= ymax; y++) {
      for (int x = xmin; x <= xmax; x++) {
        cropData.add(bitmap[y * w + x]);
      }
    }
    final crop = cv.Mat.fromList(height, width, cv.MatType.CV_32FC1, cropData);
    return crop.mean(mask: mask).val1;
  }

  /// Converts an OpenCV point vector into `[x, y]` pairs.
  List<List<double>> _vecPointToDoubleList(cv.VecPoint contour) {
    return contour.map((p) => [p.x.toDouble(), p.y.toDouble()]).toList();
  }

  /// Grows [box] outwards by `area * ratio / perimeter`.
  ///
  /// Returns a list with the single expanded polygon, or an empty list when
  /// the input is degenerate or the expansion produced non-finite points.
  List<List<List<double>>> _unclip(List<List<double>> box, double ratio) {
    final area = _polygonArea(box).abs();
    final perimeter = _polygonPerimeter(box);
    if (box.length < 3 || perimeter <= 0 || area <= 0) {
      return [];
    }

    final distance = area * ratio / perimeter;
    if (distance <= 0) {
      return [box];
    }

    final expanded = _offsetPolygon(box, distance);
    if (expanded.length < 3 || !_isFinitePolygon(expanded)) {
      return [];
    }
    return [expanded];
  }

  /// Offsets every edge of [polygon] outwards by [distance] and joins
  /// neighbouring edges at their intersection (a mitred corner).
  List<List<double>> _offsetPolygon(List<List<double>> polygon, double distance) {
    final area = _polygonArea(polygon);
    // The sign of the signed area picks which side of an edge is "outside".
    final sign = area >= 0 ? 1.0 : -1.0;
    final count = polygon.length;
    final result = <List<double>>[];

    for (int i = 0; i < count; i++) {
      final prev = polygon[(i - 1 + count) % count];
      final curr = polygon[i];
      final next = polygon[(i + 1) % count];

      final prevLine = _offsetLine(prev, curr, distance, sign);
      final nextLine = _offsetLine(curr, next, distance, sign);
      final intersection = _lineIntersection(
        prevLine.$1,
        prevLine.$2,
        nextLine.$1,
        nextLine.$2,
      );

      if (intersection != null && intersection[0].isFinite && intersection[1].isFinite) {
        result.add(intersection);
      } else {
        // (Nearly) parallel edges: push the vertex along the averaged normal,
        // falling back to the first edge's normal when the two cancel out.
        final n1 = _edgeNormal(prev, curr, sign);
        final n2 = _edgeNormal(curr, next, sign);
        final nx = n1[0] + n2[0];
        final ny = n1[1] + n2[1];
        final nl = math.sqrt(nx * nx + ny * ny);
        if (nl > 1e-6) {
          result.add([
            curr[0] + nx / nl * distance,
            curr[1] + ny / nl * distance,
          ]);
        } else {
          result.add([
            curr[0] + n1[0] * distance,
            curr[1] + n1[1] * distance,
          ]);
        }
      }
    }

    return result;
  }

  /// Returns the segment [a]-[b] translated by [distance] along its normal.
  (List<double>, List<double>) _offsetLine(
    List<double> a,
    List<double> b,
    double distance,
    double sign,
  ) {
    final normal = _edgeNormal(a, b, sign);
    return (
      [a[0] + normal[0] * distance, a[1] + normal[1] * distance],
      [b[0] + normal[0] * distance, b[1] + normal[1] * distance],
    );
  }

  /// Unit normal of the edge [a]-[b], flipped according to [sign].
  ///
  /// Returns `[0, 0]` for an edge shorter than `1e-6`.
  List<double> _edgeNormal(List<double> a, List<double> b, double sign) {
    final dx = b[0] - a[0];
    final dy = b[1] - a[1];
    final len = math.sqrt(dx * dx + dy * dy);
    if (len <= 1e-6) {
      return [0, 0];
    }
    if (sign >= 0) {
      return [dy / len, -dx / len];
    }
    return [-dy / len, dx / len];
  }

  /// Intersection of the infinite lines through [a1]-[a2] and [b1]-[b2], or
  /// `null` when they are (nearly) parallel.
  List<double>? _lineIntersection(
    List<double> a1,
    List<double> a2,
    List<double> b1,
    List<double> b2,
  ) {
    final x1 = a1[0];
    final y1 = a1[1];
    final x2 = a2[0];
    final y2 = a2[1];
    final x3 = b1[0];
    final y3 = b1[1];
    final x4 = b2[0];
    final y4 = b2[1];

    final denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4);
    if (denom.abs() <= 1e-6) {
      return null;
    }

    final px = ((x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)) / denom;
    final py = ((x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)) / denom;
    return [px, py];
  }

  /// Whether every point of [polygon] has exactly two finite coordinates.
  bool _isFinitePolygon(List<List<double>> polygon) {
    for (final point in polygon) {
      if (point.length != 2 || !point[0].isFinite || !point[1].isFinite) {
        return false;
      }
    }
    return true;
  }

  /// Signed area of [points] (shoelace formula).
  double _polygonArea(List<List<double>> points) {
    double area = 0;
    for (int i = 0; i < points.length; i++) {
      final j = (i + 1) % points.length;
      area += points[i][0] * points[j][1] - points[j][0] * points[i][1];
    }
    return area / 2.0;
  }

  /// Length of the closed outline through [points].
  double _polygonPerimeter(List<List<double>> points) {
    double perimeter = 0;
    for (int i = 0; i < points.length; i++) {
      final j = (i + 1) % points.length;
      final dx = points[i][0] - points[j][0];
      final dy = points[i][1] - points[j][1];
      perimeter += math.sqrt(dx * dx + dy * dy);
    }
    return perimeter;
  }

  /// Limits [value] to the range `[min, max]`.
  double _clipDouble(double value, double min, double max) {
    if (value < min) return min;
    if (value > max) return max;
    return value;
  }
}

/// Text detection stage: resizes and normalises an image, runs the detection
/// model and converts its probability map into boxes in image coordinates.
class _TextDetector {
  final PPOCRArgs args;
  late final OrtSession session;
  late final String inputName;
  late final _DetResizeForTest _resize;
  late final _DBPostProcess _postProcess;

  _TextDetector(this.args);

  /// Loads the detection model from assets and builds the pre/post-processing
  /// helpers from [args].
  ///
  /// Throws an [Exception] when the model reports no inputs.
  Future<void> init(OnnxRuntime runtime) async {
    session = await runtime.createSessionFromAsset(args.detModel);
    final inputs = await session.getInputInfo();
    if (inputs.isEmpty) {
      throw Exception('检测模型没有输入信息');
    }
    inputName = inputs.first['name']!;
    _resize = _DetResizeForTest(
      limitSideLen: args.detLimitSideLen,
      limitType: args.detLimitType,
    );
    _postProcess = _DBPostProcess(
      thresh: args.detDbThresh,
      boxThresh: args.detDbBoxThresh,
      maxCandidates: 1000,
      unclipRatio: args.detDbUnclipRatio,
      useDilation: args.useDilation,
      scoreMode: args.detDbScoreMode,
      boxType: args.detBoxType,
    );
  }

  /// Detects text regions in [image].
  ///
  /// Returns polygons when `args.detBoxType` is `'poly'`, otherwise 4-point
  /// boxes ordered clockwise from the top-left corner. Throws an [Exception]
  /// when the model output has fewer than four dimensions.
  Future<List<List<List<double>>>> detect(img.Image image) async {
    final pre = _preprocess(image);
    final prediction = await _runModel(pre.input, pre.inputShape);

    final shape = prediction.shape;
    if (shape.length < 4) {
      throw Exception('检测模型输出 shape 不符合预期: $shape');
    }

    final boxes = _postProcess(
      pred: prediction.data,
      predH: shape[2],
      predW: shape[3],
      shape: pre.shape,
    );

    return args.detBoxType == 'poly'
        ? _filterTagDetResOnlyClip(boxes, image.height, image.width)
        : _filterTagDetRes(boxes, image.height, image.width);
  }

  /// Runs one inference and returns the first output as Dart values.
  ///
  /// Input and output tensors are disposed before returning, including when
  /// inference or the output read throws, so no tensor outlives this call.
  Future<({List<double> data, List<int> shape})> _runModel(
    Float32List input,
    List<int> inputShape,
  ) async {
    final inputTensor = await OrtValue.fromList(input, inputShape);
    try {
      final outputs = await session.run({inputName: inputTensor});
      try {
        final output = outputs.values.first;
        final flattened = await output.asFlattenedList();
        return (
          data: flattened.map((e) => (e as num).toDouble()).toList(),
          shape: output.shape,
        );
      } finally {
        for (final value in outputs.values) {
          await value.dispose();
        }
      }
    } finally {
      await inputTensor.dispose();
    }
  }

  /// Resizes and normalises [image] into a `[1, 3, H, W]` tensor payload.
  _DetPreprocessResult _preprocess(img.Image image) {
    final resized = _resize(image);
    final normalized = _normalizeDetImage(resized.image);
    return _DetPreprocessResult(
      input: normalized,
      inputShape: [1, 3, resized.image.height, resized.image.width],
      shape: resized.shape,
    );
  }

  /// Scales pixels to `[0, 1]`, applies per-channel mean/std normalisation
  /// and writes the result as three consecutive R, G, B planes.
  Float32List _normalizeDetImage(img.Image image) {
    const mean = [0.485, 0.456, 0.406];
    const std = [0.229, 0.224, 0.225];
    final plane = image.height * image.width;
    final out = Float32List(3 * plane);

    // Read each pixel once and scatter its channels into the three planes.
    int offset = 0;
    for (int y = 0; y < image.height; y++) {
      for (int x = 0; x < image.width; x++) {
        final pixel = image.getPixel(x, y);
        out[offset] = ((pixel.r / 255.0) - mean[0]) / std[0];
        out[plane + offset] = ((pixel.g / 255.0) - mean[1]) / std[1];
        out[2 * plane + offset] = ((pixel.b / 255.0) - mean[2]) / std[2];
        offset++;
      }
    }
    return out;
  }

  /// Orders each box clockwise, clips it to the image and drops boxes whose
  /// rounded width or height is 3 pixels or less.
  List<List<List<double>>> _filterTagDetRes(
    List<List<List<double>>> dtBoxes,
    int imgHeight,
    int imgWidth,
  ) {
    final result = <List<List<double>>>[];
    for (final rawBox in dtBoxes) {
      final box = _clipDetRes(_orderPointsClockwise(rawBox), imgHeight, imgWidth);
      final rectWidth = _distance(box[0], box[1]).round();
      final rectHeight = _distance(box[0], box[3]).round();
      if (rectWidth <= 3 || rectHeight <= 3) {
        continue;
      }
      result.add(box);
    }
    return result;
  }

  /// Clips every polygon to the image without reordering or filtering.
  List<List<List<double>>> _filterTagDetResOnlyClip(
    List<List<List<double>>> dtBoxes,
    int imgHeight,
    int imgWidth,
  ) {
    return dtBoxes.map((box) => _clipDetRes(box, imgHeight, imgWidth)).toList();
  }

  /// Returns [pts] as `[topLeft, topRight, bottomRight, bottomLeft]`.
  ///
  /// The corners with the smallest and largest `x + y` become top-left and
  /// bottom-right; the remaining two are split by `y - x`.
  List<List<double>> _orderPointsClockwise(List<List<double>> pts) {
    final sortedBySum = [...pts]..sort((a, b) => (a[0] + a[1]).compareTo(b[0] + b[1]));
    final topLeft = sortedBySum.first;
    final bottomRight = sortedBySum.last;
    final remain = [...pts]
      ..remove(topLeft)
      ..remove(bottomRight);
    remain.sort((a, b) => ((a[1] - a[0])).compareTo(b[1] - b[0]));
    final topRight = remain.first;
    final bottomLeft = remain.last;
    return [topLeft, topRight, bottomRight, bottomLeft];
  }

  /// Clamps every point to `[0, imgWidth - 1]` x `[0, imgHeight - 1]`.
  List<List<double>> _clipDetRes(List<List<double>> points, int imgHeight, int imgWidth) {
    return points
        .map(
          (p) => [
            p[0].clamp(0, imgWidth - 1).toDouble(),
            p[1].clamp(0, imgHeight - 1).toDouble(),
          ],
        )
        .toList();
  }

  /// Euclidean distance between the points [a] and [b].
  double _distance(List<double> a, List<double> b) {
    final dx = a[0] - b[0];
    final dy = a[1] - b[1];
    return math.sqrt(dx * dx + dy * dy);
  }
}

/// Greedy CTC decoder that maps recognition-model output to text.
class _CTCLabelDecoder {
  /// Lookup table from class index to character; index 0 is the blank token.
  late final List<String> character;

  /// Loads the dictionary at [dictPath], one entry per line.
  ///
  /// A trailing carriage return is stripped from each line and empty lines
  /// are skipped. When [useSpaceChar] is set, a space is appended after the
  /// dictionary entries.
  Future<void> init(String dictPath, {bool useSpaceChar = false}) async {
    final raw = await rootBundle.loadString(dictPath);
    final entries = <String>[];
    for (final line in raw.split('\n')) {
      final cleaned =
          line.endsWith('\r') ? line.substring(0, line.length - 1) : line;
      if (cleaned.isNotEmpty) {
        entries.add(cleaned);
      }
    }
    character = ['blank', ...entries, if (useSpaceChar) ' '];
  }

  /// Decodes the flattened `[batch, timeSteps, classes]` tensor [preds] into
  /// one result per batch entry.
  ///
  /// Throws an [Exception] when [shape] does not have exactly three entries.
  List<_RecDecodeResult> decode(List<double> preds, List<int> shape) {
    if (shape.length != 3) {
      throw Exception('识别模型输出 shape 不符合预期: $shape');
    }

    final [batch, timeSteps, classes] = shape;
    return [
      for (var b = 0; b < batch; b++)
        _decodeSequence(preds, b * timeSteps * classes, timeSteps, classes),
    ];
  }

  /// Decodes the sequence starting at offset [base] of [preds].
  ///
  /// Takes the best class at every time step, drops blanks and immediate
  /// repeats, and averages the scores of the characters that remain.
  _RecDecodeResult _decodeSequence(
    List<double> preds,
    int base,
    int timeSteps,
    int classes,
  ) {
    // Pass 1: best class and score per time step.
    final bestIndex = List<int>.filled(timeSteps, 0);
    final bestScore = List<double>.filled(timeSteps, -double.infinity);
    for (var t = 0; t < timeSteps; t++) {
      final rowStart = base + t * classes;
      for (var c = 0; c < classes; c++) {
        final candidate = preds[rowStart + c];
        if (candidate > bestScore[t]) {
          bestScore[t] = candidate;
          bestIndex[t] = c;
        }
      }
    }

    // Pass 2: collapse repeats and blanks into text.
    final text = StringBuffer();
    final kept = <double>[];
    for (var t = 0; t < timeSteps; t++) {
      final idx = bestIndex[t];
      final isBlank = idx == 0;
      final isRepeat = t > 0 && idx == bestIndex[t - 1];
      if (isBlank || isRepeat || idx >= character.length) {
        continue;
      }
      text.write(character[idx]);
      kept.add(bestScore[t]);
    }

    final mean =
        kept.isEmpty ? 0.0 : kept.reduce((a, b) => a + b) / kept.length;
    return _RecDecodeResult(text.toString(), mean);
  }
}

/// Maps angle-classifier output rows to their best label and score.
class _ClsPostProcess {
  /// Labels indexed by output class.
  final List<String> labelList;

  const _ClsPostProcess(this.labelList);

  /// Decodes the flattened `[batch, classes]` tensor [preds] into one result
  /// per batch row.
  ///
  /// Throws an [Exception] when [shape] does not have exactly two entries.
  List<_ClsDecodeResult> decode(List<double> preds, List<int> shape) {
    if (shape.length != 2) {
      throw Exception('方向分类模型输出 shape 不符合预期: $shape');
    }

    final [batch, classes] = shape;
    return [
      for (var row = 0; row < batch; row++)
        _bestOfRow(preds, row * classes, classes),
    ];
  }

  /// Picks the highest-scoring class among the [classes] values that start at
  /// offset [start]; the first maximum wins on ties.
  _ClsDecodeResult _bestOfRow(List<double> preds, int start, int classes) {
    var winner = 0;
    var winnerScore = -double.infinity;
    for (var c = 0; c < classes; c++) {
      final candidate = preds[start + c];
      if (candidate > winnerScore) {
        winnerScore = candidate;
        winner = c;
      }
    }
    return _ClsDecodeResult(labelList[winner], winnerScore);
  }
}

/// Text recognition stage: batches cropped text lines, runs the recognition
/// model and decodes its output into text with a confidence score.
class _TextRecognizer {
  final PPOCRArgs args;
  final _CTCLabelDecoder _decoder = _CTCLabelDecoder();
  late final OrtSession session;
  late final String inputName;

  _TextRecognizer(this.args);

  /// Loads the recognition model and the character dictionary from assets.
  ///
  /// Throws an [Exception] when the model reports no inputs.
  Future<void> init(OnnxRuntime runtime) async {
    session = await runtime.createSessionFromAsset(args.recModel);
    final inputs = await session.getInputInfo();
    if (inputs.isEmpty) {
      throw Exception('识别模型没有输入信息');
    }
    inputName = inputs.first['name']!;
    await _decoder.init(args.dictPath, useSpaceChar: args.useSpaceChar);
  }

  /// Recognizes every image in [imageList]; result `i` belongs to image `i`.
  ///
  /// Images are processed in batches ordered by aspect ratio so that crops of
  /// similar width share a batch and padding stays small.
  Future<List<_RecDecodeResult>> recognize(List<img.Image> imageList) async {
    if (imageList.isEmpty) {
      return const [];
    }

    final indexed = imageList.asMap().entries.toList()
      ..sort((a, b) {
        final wa = a.value.width / a.value.height;
        final wb = b.value.width / b.value.height;
        return wa.compareTo(wb);
      });

    final results = List<_RecDecodeResult>.filled(
      imageList.length,
      const _RecDecodeResult('', 0.0),
    );

    // A non-positive batch size would never advance the loop sensibly.
    final batchSize = math.max(1, args.recBatchNum);

    for (int beg = 0; beg < indexed.length; beg += batchSize) {
      final end = math.min(indexed.length, beg + batchSize);
      double maxWhRatio = 0;
      for (int i = beg; i < end; i++) {
        final image = indexed[i].value;
        maxWhRatio = math.max(maxWhRatio, image.width / image.height);
      }

      final batchData = Float32List((end - beg) * _singleImageElementCount(maxWhRatio));
      int offset = 0;
      for (int i = beg; i < end; i++) {
        final normalized = _resizeNormImg(indexed[i].value, maxWhRatio);
        batchData.setRange(offset, offset + normalized.length, normalized);
        offset += normalized.length;
      }

      final inputShape = [
        end - beg,
        args.recImageShape[0],
        args.recImageShape[1],
        _batchWidth(maxWhRatio),
      ];

      final prediction = await _runModel(batchData, inputShape);
      final decoded = _decoder.decode(prediction.data, prediction.shape);
      for (int i = 0; i < decoded.length; i++) {
        // Undo the aspect-ratio sort: store under the original index.
        results[indexed[beg + i].key] = decoded[i];
      }
    }

    return results;
  }

  /// Runs one inference and returns the first output as Dart values.
  ///
  /// Input and output tensors are disposed before returning, including when
  /// inference or the output read throws, so no tensor outlives this call.
  Future<({List<double> data, List<int> shape})> _runModel(
    Float32List input,
    List<int> inputShape,
  ) async {
    final inputTensor = await OrtValue.fromList(input, inputShape);
    try {
      final outputs = await session.run({inputName: inputTensor});
      try {
        final output = outputs.values.first;
        final flattened = await output.asFlattenedList();
        return (
          data: flattened.map((e) => (e as num).toDouble()).toList(),
          shape: output.shape,
        );
      } finally {
        for (final value in outputs.values) {
          await value.dispose();
        }
      }
    } finally {
      await inputTensor.dispose();
    }
  }

  /// Tensor width for a batch whose widest image has ratio [maxWhRatio]:
  /// never below the configured width, and wide enough for the widest image.
  int _batchWidth(double maxWhRatio) {
    final imgH = args.recImageShape[1];
    return math.max(args.recImageShape[2], (imgH * maxWhRatio).ceil());
  }

  /// Number of floats one image occupies in a batch of the given ratio.
  int _singleImageElementCount(double maxWhRatio) {
    final imgC = args.recImageShape[0];
    final imgH = args.recImageShape[1];
    return imgC * imgH * _batchWidth(maxWhRatio);
  }

  /// Resizes [image] to the model height, maps pixels to `[-1, 1]` and
  /// right-pads with zeros up to the batch width.
  Float32List _resizeNormImg(img.Image image, double maxWhRatio) {
    final imgC = args.recImageShape[0];
    final imgH = args.recImageShape[1];
    final imgW = _batchWidth(maxWhRatio);

    final ratio = image.width / image.height;
    final resizedW = math.min(imgW, (imgH * ratio).ceil());
    final resized = img.copyResize(image, width: resizedW, height: imgH);

    // Float32List starts zero-filled, which already provides the padding for
    // columns at or beyond resizedW.
    final plane = imgH * imgW;
    final output = Float32List(imgC * plane);
    for (int y = 0; y < imgH; y++) {
      for (int x = 0; x < resizedW; x++) {
        final pixel = resized.getPixel(x, y);
        for (int c = 0; c < imgC; c++) {
          final value = switch (c) {
            0 => pixel.r,
            1 => pixel.g,
            _ => pixel.b,
          };
          output[c * plane + y * imgW + x] = ((value / 255.0) - 0.5) / 0.5;
        }
      }
    }
    return output;
  }
}

/// Angle classification stage: predicts whether each text crop is upside
/// down and rotates confident `'180'` predictions back.
class _TextClassifier {
  final PPOCRArgs args;
  final _ClsPostProcess _postProcess;
  late final OrtSession session;
  late final String inputName;

  _TextClassifier(this.args) : _postProcess = _ClsPostProcess(args.labelList);

  /// Loads the classification model from assets.
  ///
  /// Throws an [Exception] when no model is configured or the model reports
  /// no inputs.
  Future<void> init(OnnxRuntime runtime) async {
    final model = args.clsModel;
    if (model == null || model.isEmpty) {
      throw Exception('未配置方向分类模型');
    }
    session = await runtime.createSessionFromAsset(model);
    final inputs = await session.getInputInfo();
    if (inputs.isEmpty) {
      throw Exception('方向分类模型没有输入信息');
    }
    inputName = inputs.first['name']!;
  }

  /// Classifies every image in [images].
  ///
  /// Returns a copy of the list in which crops predicted as `'180'` with a
  /// score above `args.clsThresh` are rotated by 180 degrees, plus the raw
  /// prediction for each input. [images] itself is left untouched.
  Future<({List<img.Image> images, List<_ClsDecodeResult> results})> classify(List<img.Image> images) async {
    if (images.isEmpty) {
      return (images: <img.Image>[], results: <_ClsDecodeResult>[]);
    }

    final working = [...images];
    final indexed = working.asMap().entries.toList()
      ..sort((a, b) {
        final wa = a.value.width / a.value.height;
        final wb = b.value.width / b.value.height;
        return wa.compareTo(wb);
      });

    final clsRes = List<_ClsDecodeResult>.filled(
      working.length,
      const _ClsDecodeResult('', 0.0),
    );

    final imgC = args.clsImageShape[0];
    final imgH = args.clsImageShape[1];
    final imgW = args.clsImageShape[2];
    // A non-positive batch size would never advance the loop sensibly.
    final batchSize = math.max(1, args.clsBatchNum);

    for (int beg = 0; beg < indexed.length; beg += batchSize) {
      final end = math.min(indexed.length, beg + batchSize);
      final batchData = Float32List((end - beg) * imgC * imgH * imgW);

      int offset = 0;
      for (int i = beg; i < end; i++) {
        final normalized = _resizeNormImg(indexed[i].value);
        batchData.setRange(offset, offset + normalized.length, normalized);
        offset += normalized.length;
      }

      final prediction = await _runModel(batchData, [end - beg, imgC, imgH, imgW]);
      final decoded = _postProcess.decode(prediction.data, prediction.shape);
      for (int i = 0; i < decoded.length; i++) {
        // Undo the aspect-ratio sort: store under the original index.
        final originalIndex = indexed[beg + i].key;
        clsRes[originalIndex] = decoded[i];
        if (decoded[i].label.contains('180') && decoded[i].score > args.clsThresh) {
          working[originalIndex] = img.copyRotate(working[originalIndex], angle: 180);
        }
      }
    }

    return (images: working, results: clsRes);
  }

  /// Runs one inference and returns the first output as Dart values.
  ///
  /// Input and output tensors are disposed before returning, including when
  /// inference or the output read throws, so no tensor outlives this call.
  Future<({List<double> data, List<int> shape})> _runModel(
    Float32List input,
    List<int> inputShape,
  ) async {
    final inputTensor = await OrtValue.fromList(input, inputShape);
    try {
      final outputs = await session.run({inputName: inputTensor});
      try {
        final output = outputs.values.first;
        final flattened = await output.asFlattenedList();
        return (
          data: flattened.map((e) => (e as num).toDouble()).toList(),
          shape: output.shape,
        );
      } finally {
        for (final value in outputs.values) {
          await value.dispose();
        }
      }
    } finally {
      await inputTensor.dispose();
    }
  }

  /// Resizes [image] to the model height, maps pixels to `[-1, 1]` and
  /// right-pads with zeros up to the model width.
  Float32List _resizeNormImg(img.Image image) {
    final imgC = args.clsImageShape[0];
    final imgH = args.clsImageShape[1];
    final imgW = args.clsImageShape[2];

    final ratio = image.width / image.height;
    final resizedW = math.min(imgW, (imgH * ratio).ceil());
    final resized = img.copyResize(image, width: resizedW, height: imgH);

    // Float32List starts zero-filled, which already provides the padding for
    // columns at or beyond resizedW.
    final plane = imgH * imgW;
    final output = Float32List(imgC * plane);
    for (int y = 0; y < imgH; y++) {
      for (int x = 0; x < resizedW; x++) {
        final pixel = resized.getPixel(x, y);
        for (int c = 0; c < imgC; c++) {
          final value = switch (c) {
            0 => pixel.r,
            1 => pixel.g,
            _ => pixel.b,
          };
          output[c * plane + y * imgW + x] = ((value / 255.0) - 0.5) / 0.5;
        }
      }
    }
    return output;
  }
}

/// Entry point of the OCR pipeline: detection, optional angle
/// classification and recognition behind static methods.
class PPOCRv5 {
  /// Pipeline settings; adjust them before the first [init] call.
  static final PPOCRArgs args = PPOCRArgs();

  /// ONNX runtime shared by all models; set by [init].
  static late final OnnxRuntime runtime;
  static late final _TextDetector _detector;
  static late final _TextRecognizer _recognizer;
  static _TextClassifier? _classifier;
  static bool _initialized = false;

  /// In-flight initialization, shared by concurrent [init] callers.
  static Future<void>? _pendingInit;

  /// Loads all configured models. Does nothing once loading has succeeded.
  ///
  /// Concurrent calls share a single load. If loading fails the error is
  /// propagated and a later call may retry.
  static Future<void> init() async {
    if (_initialized) return;
    final pending = _pendingInit ??= _loadModels();
    try {
      await pending;
    } finally {
      // Clear only our own future so a retry started meanwhile is kept.
      if (identical(_pendingInit, pending)) {
        _pendingInit = null;
      }
    }
  }

  /// Creates and initializes every stage, then publishes them.
  ///
  /// The `late final` statics are assigned only after all models loaded, so
  /// a failure part-way leaves them unset and [init] can be called again.
  static Future<void> _loadModels() async {
    final ortRuntime = OnnxRuntime();
    final detector = _TextDetector(args);
    final recognizer = _TextRecognizer(args);

    await detector.init(ortRuntime);
    await recognizer.init(ortRuntime);
    _TextClassifier? classifier;
    if (args.useAngleCls) {
      classifier = _TextClassifier(args);
      await classifier.init(ortRuntime);
    }

    runtime = ortRuntime;
    _detector = detector;
    _recognizer = recognizer;
    _classifier = classifier;
    _initialized = true;
  }

  /// Runs OCR on [image].
  ///
  /// * [det] and [rec]: detects regions and recognizes each one; every entry
  ///   has `'box'`, `'text'` and `'score'`, and entries scoring below
  ///   `args.dropScore` are omitted.
  /// * [det] only: every entry has just `'box'`.
  /// * [det] false: [image] is treated as a single text line and every entry
  ///   has `'text'` and `'score'`.
  ///
  /// [cls] enables angle classification when a classifier is configured.
  /// Throws an [ArgumentError] when [det] and [rec] are both set and a
  /// detected region does not have exactly four points.
  static Future<List<Map<String, dynamic>>> recognize(
    img.Image image, {
    bool det = true,
    bool rec = true,
    bool cls = true,
  }) async {
    await init();

    if (det && rec) {
      final dtBoxes = await _detector.detect(image);
      if (dtBoxes.isEmpty) {
        return [];
      }

      final sorted = _sortedBoxes(dtBoxes);
      // Convert the source image once instead of once per detected box.
      final source = _ImageConverters.imageToMat(image);
      var crops = sorted.map((box) => _getRotateCropImage(source, box)).toList();
      if (args.useAngleCls && cls && _classifier != null) {
        final clsOut = await _classifier!.classify(crops);
        crops = clsOut.images;
      }

      final recRes = await _recognizer.recognize(crops);
      final ocrRes = <Map<String, dynamic>>[];
      for (int i = 0; i < sorted.length; i++) {
        final res = recRes[i];
        if (res.score >= args.dropScore) {
          ocrRes.add({
            'box': sorted[i]
                .map((p) => [p[0].round(), p[1].round()])
                .toList(),
            'text': res.text,
            'score': res.score,
          });
        }
      }
      return ocrRes;
    }

    if (det && !rec) {
      final dtBoxes = await _detector.detect(image);
      return dtBoxes
          .map(
            (box) => {
              'box': box.map((p) => [p[0].round(), p[1].round()]).toList(),
            },
          )
          .toList();
    }

    var images = [image];
    if (args.useAngleCls && cls && _classifier != null) {
      final clsOut = await _classifier!.classify(images);
      images = clsOut.images;
    }
    final recRes = await _recognizer.recognize(images);
    return recRes
        .map(
          (res) => {
            'text': res.text,
            'score': res.score,
          },
        )
        .toList();
  }

  /// Decodes the image at [assetPath] and runs the full pipeline on it.
  ///
  /// Throws an [Exception] when the asset cannot be decoded as an image.
  static Future<List<Map<String, dynamic>>> fromAsset(String assetPath) async {
    final data = await rootBundle.load(assetPath);
    // A ByteData may be a view into a larger buffer; honour its window.
    final bytes = data.buffer.asUint8List(data.offsetInBytes, data.lengthInBytes);
    final image = img.decodeImage(bytes);
    if (image == null) {
      throw Exception('图像解码失败: $assetPath');
    }
    return recognize(image);
  }

  /// Sorts boxes top-to-bottom, then left-to-right within a line.
  ///
  /// Boxes whose first corners differ by less than 10 pixels vertically are
  /// treated as one line and ordered by x.
  static List<List<List<double>>> _sortedBoxes(List<List<List<double>>> dtBoxes) {
    final boxes = [...dtBoxes]
      ..sort((a, b) {
        final byY = a[0][1].compareTo(b[0][1]);
        if (byY != 0) return byY;
        return a[0][0].compareTo(b[0][0]);
      });

    for (int i = 0; i < boxes.length - 1; i++) {
      for (int j = i; j >= 0; j--) {
        if ((boxes[j + 1][0][1] - boxes[j][0][1]).abs() < 10 &&
            boxes[j + 1][0][0] < boxes[j][0][0]) {
          final tmp = boxes[j];
          boxes[j] = boxes[j + 1];
          boxes[j + 1] = tmp;
        } else {
          break;
        }
      }
    }
    return boxes;
  }

  /// Cuts the quadrilateral [points] out of [source] and straightens it with
  /// a perspective warp.
  ///
  /// Crops at least 1.5 times taller than wide are rotated 90 degrees
  /// counter-clockwise. Throws an [ArgumentError] unless [points] has exactly
  /// four entries.
  static img.Image _getRotateCropImage(cv.Mat source, List<List<double>> points) {
    if (points.length != 4) {
      throw ArgumentError('shape of points must be 4*2');
    }

    final cropWidth = math.max(
      _distance(points[0], points[1]),
      _distance(points[2], points[3]),
    ).toInt();
    final cropHeight = math.max(
      _distance(points[0], points[3]),
      _distance(points[1], points[2]),
    ).toInt();

    // Degenerate boxes still need a non-empty target size.
    final safeWidth = math.max(cropWidth, 1);
    final safeHeight = math.max(cropHeight, 1);

    final src = cv.VecPoint2f.fromList(
      points.map((p) => cv.Point2f(p[0], p[1])).toList(),
    );
    final dst = cv.VecPoint2f.fromList([
      cv.Point2f(0, 0),
      cv.Point2f(safeWidth.toDouble(), 0),
      cv.Point2f(safeWidth.toDouble(), safeHeight.toDouble()),
      cv.Point2f(0, safeHeight.toDouble()),
    ]);

    final matrix = cv.getPerspectiveTransform2f(src, dst);
    var warped = cv.warpPerspective(
      source,
      matrix,
      (safeWidth, safeHeight),
      flags: cv.INTER_CUBIC,
      borderMode: cv.BORDER_REPLICATE,
    );

    if (warped.rows / warped.cols >= 1.5) {
      warped = warped.rotate(cv.ROTATE_90_COUNTERCLOCKWISE);
    }
    return _ImageConverters.matToImage(warped);
  }

  /// Euclidean distance between the points [a] and [b].
  static double _distance(List<double> a, List<double> b) {
    final dx = a[0] - b[0];
    final dy = a[1] - b[1];
    return math.sqrt(dx * dx + dy * dy);
  }
}
最后修改:2026 年 04 月 06 日 06:49 PM
如果觉得我的文章对你有用,请随意赞赏