控制器(Controller — Gin HTTP 处理函数片段,非完整函数)

// Bind the JSON request body into the verify-code request struct; reject
// malformed input with a 500-style error.
var captchaRequest request.VerifyCodeRequest

if err := ctx.ShouldBindWith(&captchaRequest, binding.JSON); err != nil {
    helper.Info(err.Error())
    ReturnError(ctx, defined.ERROR_500)
    return
}

// Decrypt the AES-encrypted client payload. Despite the name, jsonStr is a
// []byte (CaptchaAESDecrypt returns raw bytes).
// NOTE(review): on decrypt failure this responds with ReturnSuccess(nil) —
// presumably deliberate (avoid leaking crypto errors to clients), confirm.
var jsonStr, err = service.CaptchaAESDecrypt(captchaRequest.Ciphertext)
if err != nil {
    helper.Error(err)
    ReturnSuccess(ctx, nil)
    return
}

// Holder for the decrypted telemetry payload.
var rawData service.CaptchaRawData

// Unmarshal requires JSON field types to match the struct field types
// (numbers vs int64, strings vs string).
err = json.Unmarshal([]byte(jsonStr), &rawData)
if err != nil {
    ReturnError(ctx, defined.ERROR_500)
    return
}

// Score and persist the sample, then echo the parsed payload to the caller.
service.CaptchaConsole(rawData, rawData.Env.SessionId, ctx)

ReturnSuccess(ctx, rawData)

服务(Service — Go 业务逻辑、特征提取与 AES 解密实现)

package service

import (
    "activity/cache"
    "activity/dao"
    "activity/etc"
    "activity/helper"
    "activity/model"
    "activity/third/idc"
    "activity/third/ip2region"
    "activity/utils"
    "crypto/aes"
    "crypto/cipher"
    "encoding/hex"
    "errors"
    "fmt"
    "github.com/gin-gonic/gin"
    "github.com/google/uuid"
    log "github.com/sirupsen/logrus"
    "gonum.org/v1/gonum/stat"
    "math"
)

// Threshold constants for the rule-based anti-cheat scoring, centralized here
// for easy maintenance.
const (
    // Generic thresholds (base rules shared by every verify type)
    DeviceIPBindCountMaxNormal   = 2  // max IPs bound to one device still "normal"
    DeviceIPBindCountMinRisk     = 3  // lower bound of the "risk" IP-bind range
    DeviceIPBindCountMaxRisk     = 4  // upper bound of the "risk" IP-bind range
    DeviceIPBindCountMinAbnormal = 5  // IP-bind count at/above which is "abnormal"
    IPVerifyFreqMinRisk          = 10 // lower bound of the "risk" verify frequency
    IPVerifyFreqMaxRisk          = 19 // upper bound of the "risk" verify frequency
    IPVerifyFreqMinAbnormal      = 20 // verify frequency at/above which is "abnormal"

    // ------------ load_to_first_delay thresholds per verify type (ms) ------------
    // Swipe: simplest gesture, shortest expected delays.
    SwipeDelayMinRisk     = 80   // swipe: risk lower bound (slightly short)
    SwipeDelayMinNormal   = 200  // swipe: normal lower bound
    SwipeDelayMaxNormal   = 2000 // swipe: normal upper bound
    SwipeDelayMaxRisk     = 3000 // swipe: risk upper bound (slightly long)
    SwipeDelayMinAbnormal = 50   // swipe: abnormal lower bound (too short)
    SwipeDelayMaxAbnormal = 4000 // swipe: abnormal upper bound (too long)

    // Text click: the user must locate targets, so delays are moderate.
    ClickDelayMinRisk     = 100  // click: risk lower bound
    ClickDelayMinNormal   = 300  // click: normal lower bound
    ClickDelayMaxNormal   = 7000 // click: normal upper bound
    ClickDelayMaxRisk     = 8000 // click: risk upper bound
    ClickDelayMinAbnormal = 70   // click: abnormal lower bound
    ClickDelayMaxAbnormal = 9000 // click: abnormal upper bound

    // Text input: reading plus typing, so the longest delays are expected.
    InputDelayMinRisk     = 200   // input: risk lower bound
    InputDelayMinNormal   = 500   // input: normal lower bound
    InputDelayMaxNormal   = 5000  // input: normal upper bound
    InputDelayMaxRisk     = 7000  // input: risk upper bound
    InputDelayMinAbnormal = 100   // input: abnormal lower bound
    InputDelayMaxAbnormal = 10000 // input: abnormal upper bound (longer hesitation allowed)

    // Other types (drag, ...): medium default delays.
    DefaultDelayMinRisk     = 150  // default: risk lower bound
    DefaultDelayMinNormal   = 350  // default: normal lower bound
    DefaultDelayMaxNormal   = 3000 // default: normal upper bound
    DefaultDelayMaxRisk     = 4000 // default: risk upper bound
    DefaultDelayMinAbnormal = 80   // default: abnormal lower bound
    DefaultDelayMaxAbnormal = 5000 // default: abnormal upper bound

    // Click-type thresholds (coefficient of variation of click intervals)
    ClickCVLowAbnormal  = 0.08 // click CV too low (abnormal: machine-regular)
    ClickCVHighAbnormal = 0.85 // click CV too high (abnormal: erratic)
    ClickCVNormalMin    = 0.09 // click CV normal lower bound

    // Swipe-type thresholds (average interval between move events, ms)
    SwipeIntervalMinAbnormal = 2  // swipe interval too low (abnormal)
    SwipeIntervalMaxAbnormal = 12 // swipe interval too high (abnormal)
    SwipeIntervalMinRiskFast = 2  // slightly-fast swipe lower bound (risk)
    SwipeIntervalMaxRiskFast = 3  // slightly-fast swipe upper bound (risk)
    SwipeIntervalMinRiskSlow = 10 // slightly-slow swipe lower bound (risk)
    SwipeIntervalMaxRiskSlow = 12 // slightly-slow swipe upper bound (risk)
    SwipeIntervalMinNormal   = 3  // normal swipe lower bound
    SwipeIntervalMaxNormal   = 10 // normal swipe upper bound

    // Input-type thresholds (average interval between keystrokes, ms)
    InputIntervalMinAbnormal = 150  // input interval too low (abnormal: autofill)
    InputIntervalMaxAbnormal = 2500 // input interval too high (abnormal)
    InputIntervalMinRiskFast = 150  // slightly-fast input lower bound (risk)
    InputIntervalMaxRiskFast = 400  // slightly-fast input upper bound (risk)
    InputIntervalMinRiskSlow = 2000 // slightly-slow input lower bound (risk)
    InputIntervalMaxRiskSlow = 2500 // slightly-slow input upper bound (risk)
    InputIntervalMinNormal   = 400  // normal input lower bound
    InputIntervalMaxNormal   = 2000 // normal input upper bound
)

// CaptchaConsole runs the anti-cheat pipeline for one captcha verification
// sample: it validates the timestamp payload, extracts features, scores the
// sample with the ML model, and persists the feature row.
//
// Fixes:
//   - the original called log.Fatalf when prediction failed, which terminates
//     the entire process; a single bad sample must not crash the server, so it
//     now logs the error and returns.
//   - Save was given &feature, a **model.CaptchaVerifyFeature (feature is
//     already a pointer); it now passes the pointer directly.
func CaptchaConsole(raw CaptchaRawData, sessionId string, req *gin.Context) {

    if !validateTimestamps(raw) {
        helper.Warning("验证码校验数据异常")
        return
    }

    feature, err := ExtractFeatures(&raw, sessionId, req)
    if err != nil {
        helper.Error("验证码收集错误", err.Error())
        return
    }

    // Feature vector handed to the prediction model (keys must match the
    // model's expected feature names).
    swipeData := map[string]interface{}{
        "verify_type":          feature.VerifyType,
        "os":                   feature.OS,
        "device_type":          feature.DeviceType,
        "ip_type":              feature.IPType,
        "load_to_first_delay":  feature.LoadToFirstDelay,
        "total_duration":       feature.TotalDuration,
        "avg_swipe_interval":   feature.AvgSwipeInterval,
        "ip_verify_freq":       feature.IPVerifyFreq,
        "device_ip_bind_count": feature.DeviceIPBindCount,
    }

    swipeResult, err := utils.CaptchaPredict.PredictCaptcha(swipeData)
    if err != nil {
        // A prediction failure is not fatal: log and skip this sample.
        helper.Error("验证预测失败", err.Error())
        return
    }

    helper.Info("预测模型", swipeResult.AbnormalProbability, swipeResult.LabelDescription)

    // Score usage: logging / manual-review reference / caching, etc.
    helper.Printf("样本[%s] 反作弊得分:%d,标签:%s", feature.SampleID, feature.TempScore, feature.Label)
    dao.GetPlatformDB().Model(&model.CaptchaVerifyFeature{}).Save(feature)
}

// validateTimestamps checks that the timing payload carries the timestamp
// series required by the sample's verify type; it logs and rejects otherwise.
func validateTimestamps(raw CaptchaRawData) bool {
    sid := raw.Env.SessionId

    switch vt := raw.Env.VerifyType; vt {
    case "swipe", "click": // gesture-based verification
        if len(raw.Timing.OperateTimestamps) > 0 {
            return true
        }
        log.Printf("样本[%s]:滑动/点选验证缺少operateTimestamps", sid)
        return false
    case "input", "sms", "text": // keystroke-based verification
        if len(raw.Timing.InputTimestamps) > 0 {
            return true
        }
        log.Printf("样本[%s]:输入类验证缺少inputTimestamps", sid)
        return false
    case "none": // frictionless verification: nothing required
        return true
    default: // unrecognized verify type
        log.Printf("样本[%s]:未知验证类型[%s]", sid, vt)
        return false
    }
}

// ========== Struct definitions (mirroring the raw client payload) ==========
// Env carries the client environment section of the raw payload.
type Env struct {
    VerifyType        string `json:"verifyType"`        // captcha type (swipe/click/input/...)
    DeviceFingerprint string `json:"deviceFingerprint"` // device fingerprint
    OS                string `json:"os"`                // operating system
    DeviceType        string `json:"deviceType"`        // device type
    // Newly added fields
    Browser       string `json:"browser"`       // browser (including version)
    CookieEnabled bool   `json:"cookieEnabled"` // whether cookies are enabled
    PluginCount   int    `json:"pluginCount"`   // number of browser plugins
    ScreenRes     string `json:"screenRes"`     // screen resolution
    SessionId     string `json:"sessionId"`     // session ID
    UserAgent     string `json:"userAgent"`     // user agent (encrypted/hashed client-side)
}

// Timing carries the timing statistics section of the raw payload.
type Timing struct {
    FirstOperateTime  int64   `json:"firstOperateTime"`  // timestamp of the first user action
    LoadTime          int64   `json:"loadTime"`          // timestamp when the captcha loaded
    SubmitTime        int64   `json:"submitTime"`        // timestamp of submission
    OperateTimestamps []int64 `json:"operateTimestamps"` // action timestamps (swipe/click/drag)
    InputTimestamps   []int64 `json:"inputTimestamps"`   // keystroke timestamps (input captcha)
    CorrectCount      int     `json:"correctCount"`      // number of correct inputs
}

// CaptchaRawData is the root of the decrypted client telemetry payload.
type CaptchaRawData struct {
    Env    Env    `json:"env"`
    Timing Timing `json:"timing"`
}

// Request is a simplified request struct (extend as needed).
// NOTE(review): appears unused in this file — confirm before removing.
type Request struct {
    IP string `json:"ip"` // client IP (used for IP-type / frequency lookups)
}

// ========== External dependency helpers ==========
// getIPType classifies the caller's real IP: IDC ranges and untrusted proxies
// are flagged; otherwise the ISP name resolved by ip2region is returned.
func getIPType(req *gin.Context, realIP string) string {
    tag, err := idc.Search(realIP)
    if err != nil {
        return "异常"
    }
    if tag != "" {
        return "IDC"
    }

    // Proxy detection is skipped in the test environment.
    if etc.CommonConfigVar.Common.Env != "test" && helper.IsViaUntrustedProxy(req) {
        return "代理"
    }

    _, regionInfo, _ := ip2region.Search(realIP)
    if regionInfo.ISP == "" {
        return "未知"
    }
    return regionInfo.ISP
}

// getDeviceIPBindCount records the (device fingerprint -> IP) membership in a
// Redis set with a 7-day TTL and returns how many distinct IPs the device has
// used within that window.
func getDeviceIPBindCount(deviceFP string, realIP string) int {
    bindKey := fmt.Sprintf(cache.CAPTCHA_VERIFY_DEVICE_IP, deviceFP)
    cache.Cache.SAddWithExpire(bindKey, realIP, 3600*24*7)
    distinctIPs, _ := cache.Cache.SCard(bindKey)
    return distinctIPs
}

// getIPVerifyFreq increments and returns this IP's verification counter in
// Redis; the key expires after 3600 seconds.
// NOTE(review): the old comment said a 4-hour window but the TTL is 3600s
// (1 hour) — confirm the intended window.
func getIPVerifyFreq(realIP string) int {
    freqKey := fmt.Sprintf(cache.CAPTCHA_VERIFY_IP_FREQ, realIP)
    hits, _ := cache.Cache.IncrWithExpire(freqKey, 3600)
    return int(hits)
}

// DelayThreshold groups the load_to_first_delay thresholds for one verify
// type. The bands nest as: abnormal < risk < normal < risk < abnormal.
type DelayThreshold struct {
    MinAbnormal int // abnormal floor (below this is abnormal)
    MaxAbnormal int // abnormal ceiling (above this is abnormal)
    MinRisk     int // risk floor (between abnormal floor and normal floor)
    MaxRisk     int // risk ceiling (between normal ceiling and abnormal ceiling)
    MinNormal   int // normal floor
    MaxNormal   int // normal ceiling
}

// getDelayThresholdByType returns the load_to_first_delay threshold set for a
// verify type; unrecognized types (e.g. "drag") fall back to the defaults.
func getDelayThresholdByType(verifyType string) DelayThreshold {
    switch verifyType {
    case "swipe":
        return DelayThreshold{
            MinAbnormal: SwipeDelayMinAbnormal, MaxAbnormal: SwipeDelayMaxAbnormal,
            MinRisk: SwipeDelayMinRisk, MaxRisk: SwipeDelayMaxRisk,
            MinNormal: SwipeDelayMinNormal, MaxNormal: SwipeDelayMaxNormal,
        }
    case "click":
        return DelayThreshold{
            MinAbnormal: ClickDelayMinAbnormal, MaxAbnormal: ClickDelayMaxAbnormal,
            MinRisk: ClickDelayMinRisk, MaxRisk: ClickDelayMaxRisk,
            MinNormal: ClickDelayMinNormal, MaxNormal: ClickDelayMaxNormal,
        }
    case "input":
        return DelayThreshold{
            MinAbnormal: InputDelayMinAbnormal, MaxAbnormal: InputDelayMaxAbnormal,
            MinRisk: InputDelayMinRisk, MaxRisk: InputDelayMaxRisk,
            MinNormal: InputDelayMinNormal, MaxNormal: InputDelayMaxNormal,
        }
    }
    // drag and any other type share the default band.
    return DelayThreshold{
        MinAbnormal: DefaultDelayMinAbnormal, MaxAbnormal: DefaultDelayMaxAbnormal,
        MinRisk: DefaultDelayMinRisk, MaxRisk: DefaultDelayMaxRisk,
        MinNormal: DefaultDelayMinNormal, MaxNormal: DefaultDelayMaxNormal,
    }
}

// autoLabel scores one feature row with ordered rule-based heuristics and
// returns (label, score), where label is "abnormal" / "normal" / "risk" and
// score starts at 100 with additive penalties/bonuses. Rule order matters:
// earlier rules short-circuit with an immediate return.
func autoLabel(features *model.CaptchaVerifyFeature, timing *Timing) (string, int) {
    // Guard: with no features, the only safe answer is risk with score 0.
    if features == nil {
        return "risk", 0
    }

    // Silence the unused-parameter warning (timing reserved for future rules).
    _ = timing

    // Base anti-cheat score (100 points).
    score := 100
    verifyType := features.VerifyType
    // Per-type load_to_first_delay thresholds.
    delayThreshold := getDelayThresholdByType(verifyType)
    // Delay from load to first action, cast to int so it compares directly
    // against the int-typed thresholds (field is int64).
    currentDelay := int(features.LoadToFirstDelay) // key change: int64 -> int

    // ========== Step 1: generic abnormal rules (return immediately) ==========
    // Rule 1: delay below the per-type abnormal floor (too fast for a human).
    if currentDelay < delayThreshold.MinAbnormal { // both int, direct compare
        score -= 30
        return "abnormal", score
    }
    // Rule 2: delay above the per-type abnormal ceiling (suspiciously slow,
    // possibly a stalled bot).
    if currentDelay > delayThreshold.MaxAbnormal { // types match
        score -= 25
        return "abnormal", score
    }
    // Rule 3: device bound to >=5 IPs (multi-IP abuse).
    if features.DeviceIPBindCount >= DeviceIPBindCountMinAbnormal {
        score -= 20
        return "abnormal", score
    }
    // Rule 4: IP verification frequency >=20 (high-frequency abuse).
    if features.IPVerifyFreq >= IPVerifyFreqMinAbnormal {
        score -= 15
        return "abnormal", score
    }
    // Rule 5: high-risk IP type (abnormal / proxy / unknown / IDC).
    if helper.Contains([]string{"异常", "代理", "未知", "IDC"}, features.IPType) {
        score -= 50
        return "abnormal", score
    }

    // ========== Step 2: type-specific abnormal rules ==========
    // Flag set when a type-specific abnormal rule fires (-1 means the metric
    // was never computed for this sample).
    hasExclusiveAbnormal := false

    switch verifyType {
    case "click": // text click
        // Negative rule: click-interval CV outside the allowed band.
        if features.ClickIntervalCV != -1 {
            if features.ClickIntervalCV < ClickCVLowAbnormal || features.ClickIntervalCV > ClickCVHighAbnormal {
                score -= 25
                hasExclusiveAbnormal = true
            } else if features.ClickIntervalCV >= ClickCVLowAbnormal && features.ClickIntervalCV < ClickCVNormalMin {
                // Fuzzy rule: close to the abnormal band — deduct but do not
                // return yet.
                score -= 10
            }
        }

    case "swipe": // swipe verification
        // Negative rule: average swipe interval <2ms or >12ms (mechanical
        // motion / abnormally slow swipe).
        if features.AvgSwipeInterval != -1 &&
            (features.AvgSwipeInterval < SwipeIntervalMinAbnormal || features.AvgSwipeInterval > SwipeIntervalMaxAbnormal) {
            score -= 25
            hasExclusiveAbnormal = true
        } else if features.AvgSwipeInterval != -1 {
            // Fuzzy rule: 2~3ms (slightly fast) or 10~12ms (slightly slow).
            if (features.AvgSwipeInterval >= SwipeIntervalMinRiskFast && features.AvgSwipeInterval < SwipeIntervalMaxRiskFast) ||
                (features.AvgSwipeInterval >= SwipeIntervalMinRiskSlow && features.AvgSwipeInterval <= SwipeIntervalMaxRiskSlow) {
                score -= 8
            }
        }

    case "input": // text input
        // Negative rule: average input interval <150ms (autofill) or >2500ms
        // (abnormally slow typing).
        if features.AvgInputInterval != -1 &&
            (features.AvgInputInterval < InputIntervalMinAbnormal || features.AvgInputInterval > InputIntervalMaxAbnormal) {
            score -= 25
            hasExclusiveAbnormal = true
        } else if features.AvgInputInterval != -1 {
            // Fuzzy rule: 150~400ms (slightly fast) or 2000~2500ms (slightly slow).
            if (features.AvgInputInterval >= InputIntervalMinRiskFast && features.AvgInputInterval < InputIntervalMaxRiskFast) ||
                (features.AvgInputInterval >= InputIntervalMinRiskSlow && features.AvgInputInterval <= InputIntervalMaxRiskSlow) {
                score -= 8
            }
        }
    }

    // A type-specific abnormal rule fired: return abnormal.
    if hasExclusiveAbnormal {
        return "abnormal", score
    }

    // ========== Step 3: positive rules (real-user traits -> normal) ==========
    // Base condition: delay within the normal band AND device IP-bind count
    // within the normal range.
    baseNormalCondition := currentDelay >= delayThreshold.MinNormal &&
        currentDelay <= delayThreshold.MaxNormal &&
        features.DeviceIPBindCount <= DeviceIPBindCountMaxNormal

    // Per-type normal classification.
    switch verifyType {
    case "click":
        if baseNormalCondition && features.ClickIntervalCV != -1 && features.ClickIntervalCV > ClickCVNormalMin {
            score += 10
            return "normal", score
        }
    case "swipe":
        if baseNormalCondition && features.AvgSwipeInterval != -1 &&
            features.AvgSwipeInterval >= SwipeIntervalMinNormal && features.AvgSwipeInterval <= SwipeIntervalMaxNormal {
            score += 10
            return "normal", score
        }
    case "input":
        if baseNormalCondition && features.AvgInputInterval != -1 &&
            features.AvgInputInterval >= InputIntervalMinNormal && features.AvgInputInterval <= InputIntervalMaxNormal {
            score += 10
            return "normal", score
        }
    default:
        // Other types (e.g. drag): only the base condition applies.
        if baseNormalCondition {
            score += 10
            return "normal", score
        }
    }

    // ========== Step 4: core risk classification (per-type delay bands) ==========
    // Any one of these conditions marks the sample as risk.
    var isRisk bool
    // Condition 1: delay between abnormal floor and normal floor (slightly short).
    delayRiskLow := currentDelay >= delayThreshold.MinAbnormal && currentDelay < delayThreshold.MinNormal
    // Condition 2: delay between normal ceiling and abnormal ceiling (slightly long).
    delayRiskHigh := currentDelay > delayThreshold.MaxNormal && currentDelay <= delayThreshold.MaxAbnormal
    // Condition 3: device IP-bind count in the risk range (3-4 IPs).
    ipBindRisk := features.DeviceIPBindCount >= DeviceIPBindCountMinRisk && features.DeviceIPBindCount <= DeviceIPBindCountMaxRisk
    // Condition 4: IP verification frequency in the risk range (10-19 hits).
    ipFreqRisk := features.IPVerifyFreq >= IPVerifyFreqMinRisk && features.IPVerifyFreq <= IPVerifyFreqMaxRisk

    // Mark risk when any condition holds.
    isRisk = delayRiskLow || delayRiskHigh || ipBindRisk || ipFreqRisk

    // ========== Step 5: fine-grained deductions for risk samples ==========
    if isRisk {
        // Slightly-short delay: swipe is most delay-sensitive, deduct more.
        if delayRiskLow {
            switch verifyType {
            case "swipe":
                score -= 8 // slightly-short swipe delay: likely scripted
            case "click":
                score -= 6 // slightly-short click delay: medium risk
            case "input":
                score -= 4 // slightly-short input delay: lower risk
            default:
                score -= 5
            }
        }
        // Slightly-long delay: input users may legitimately hesitate, deduct less.
        if delayRiskHigh {
            switch verifyType {
            case "swipe":
                score -= 6 // slightly-long swipe delay: lag or bot suspected
            case "click":
                score -= 5 // slightly-long click delay: medium risk
            case "input":
                score -= 3 // slightly-long input delay: normal hesitation
            default:
                score -= 4
            }
        }
        // Slightly-high device IP-bind count: moderate deduction.
        if ipBindRisk {
            score -= 8
        }
        // Slightly-high IP verification frequency: light deduction.
        if ipFreqRisk {
            score -= 5
        }
        // Explicit risk verdict before the catch-all below.
        return "risk", score
    }

    // ========== Fallback: edge samples that hit no rule stay risk ==========
    return "risk", score
}

// ========== Core feature-extraction function ==========
// ExtractFeatures builds a CaptchaVerifyFeature row from the decrypted raw
// payload: base request/environment features, per-type behavioral features
// (click CV, swipe/input intervals, correctness), and the rule-based label.
// The error return is currently always nil.
func ExtractFeatures(rawData *CaptchaRawData, sessionID string, req *gin.Context) (*model.CaptchaVerifyFeature, error) {
    // 1. Pull the raw base sections.
    env := rawData.Env
    timing := rawData.Timing
    verifyType := env.VerifyType
    realIP := helper.GetIp(req)

    helper.Info("ip:", realIP)

    // 2. Build the base feature row.
    features := &model.CaptchaVerifyFeature{
        SampleID:          uuid.NewString(), // random UUID (like Python uuid.uuid4())
        SessionID:         sessionID,
        VerifyType:        verifyType,
        LoadToFirstDelay:  timing.FirstOperateTime - timing.LoadTime,
        TotalDuration:     timing.SubmitTime - timing.LoadTime,
        DeviceFingerprint: env.DeviceFingerprint,
        OS:                env.OS,
        DeviceType:        env.DeviceType,
        IPType:            getIPType(req, realIP),
        DeviceIPBindCount: getDeviceIPBindCount(env.DeviceFingerprint, realIP),
        IPVerifyFreq:      getIPVerifyFreq(realIP),
        Ip:                realIP,
    }

    // 3. Add type-specific behavioral features.
    switch verifyType {
    case "swipe", "click", "drag":
        operateTS := timing.OperateTimestamps
        features.OperateStepCount = len(operateTS)

        if len(operateTS) > 1 {
            // Successive differences between action timestamps (np.diff).
            intervals := make([]float64, len(operateTS)-1)
            for i := 0; i < len(operateTS)-1; i++ {
                intervals[i] = float64(operateTS[i+1] - operateTS[i])
            }

            switch verifyType {
            case "click":
                // Coefficient of variation of click intervals
                // (stddev / mean, with divide-by-zero guard).
                mean := stat.Mean(intervals, nil)
                if mean == 0 {
                    features.ClickIntervalCV = 0
                } else {
                    std := math.Sqrt(stat.Variance(intervals, nil))
                    features.ClickIntervalCV = std / mean
                }
            case "swipe":
                // Average swipe interval, truncated to int.
                mean := stat.Mean(intervals, nil)
                features.AvgSwipeInterval = int(mean)
            }
        }

    case "input":
        inputTS := timing.InputTimestamps
        features.InputCount = len(inputTS)

        // Correctness percentage (divide-by-zero guard).
        if len(inputTS) > 0 {
            features.CorrectRate = (float64(timing.CorrectCount) / float64(len(inputTS))) * 100
        } else {
            features.CorrectRate = 0
        }

        // Average interval between keystrokes.
        if len(inputTS) > 1 {
            intervals := make([]float64, len(inputTS)-1)
            for i := 0; i < len(inputTS)-1; i++ {
                intervals[i] = float64(inputTS[i+1] - inputTS[i])
            }
            mean := stat.Mean(intervals, nil)
            features.AvgInputInterval = int(mean)
        }
    }

    // 4. Rule-based auto labeling.
    features.Label, features.TempScore = autoLabel(features, &timing)

    return features, nil
}

// Crypto parameters (must match the front-end exactly).
//
// NOTE(review): key and IV are hard-coded in source, so anyone with the repo
// or binary can decrypt client payloads — consider moving them to config or a
// secret store and rotating them.
const (
    appKey = "53D8DBC5DIK3436A" // AES-128 key (16 bytes)
    appIv  = "KI5JL2SKE9883365" // independent CBC IV (no longer sliced from the key)
)

// pKCS7Unpad strips standard PKCS#7 padding (CryptoJS's Pkcs7) from data.
// With AES the block size is 16 bytes; a pad of n bytes consists of n bytes
// each holding the value n.
//
// Fix: the pad length is now also bounded by len(data). Previously a short
// input whose final byte exceeded its length (e.g. []byte{9, 5}) walked off
// the front of the slice and panicked with a negative index instead of
// returning an error.
func pKCS7Unpad(data []byte) ([]byte, error) {
    length := len(data)
    if length == 0 {
        return nil, errors.New("填充数据为空")
    }
    // The last byte encodes the pad length.
    unpadLen := int(data[length-1])
    // Pad length must be 1..16 and cannot exceed the data itself.
    if unpadLen < 1 || unpadLen > aes.BlockSize || unpadLen > length {
        return nil, errors.New("无效的填充长度")
    }
    // Every padding byte must equal the pad length (PKCS#7 rule).
    for i := 0; i < unpadLen; i++ {
        if data[length-1-i] != byte(unpadLen) {
            return nil, errors.New("填充数据不合法")
        }
    }
    // Drop the padding and return the original payload.
    return data[:length-unpadLen], nil
}

// CaptchaAESDecrypt decrypts a payload produced by the front-end CryptoJS
// pipeline: hex-encoded AES-128-CBC ciphertext keyed by appKey with IV appIv
// and PKCS#7 padding. It returns the decrypted plaintext bytes.
//
// hexCiphertext is the ciphertext field sent by the client (hex string).
func CaptchaAESDecrypt(hexCiphertext string) ([]byte, error) {
    // Step 1: hex string -> ciphertext bytes
    // (mirrors the front-end's toString(CryptoJS.enc.Hex)).
    encrypted, err := hex.DecodeString(hexCiphertext)
    if err != nil {
        return nil, fmt.Errorf("Hex解码失败: %w", err)
    }

    // Step 2: key/IV as UTF-8 bytes, exactly like
    // CryptoJS.enc.Utf8.parse(appKey) / parse(appIv) on the client.
    keyBytes, ivBytes := []byte(appKey), []byte(appIv)

    // AES-CBC requires the IV to be exactly one block (16 bytes).
    if len(ivBytes) != aes.BlockSize {
        return nil, fmt.Errorf("IV长度非法,要求16字节,实际%d字节", len(ivBytes))
    }

    // Step 3: build the AES block cipher.
    block, err := aes.NewCipher(keyBytes)
    if err != nil {
        return nil, fmt.Errorf("创建AES块失败: %w", err)
    }

    // Step 4: ciphertext must be a whole number of AES blocks.
    if len(encrypted)%aes.BlockSize != 0 {
        return nil, errors.New("密文长度不是AES块大小的整数倍")
    }

    // Steps 5-6: CBC-decrypt into a fresh buffer
    // (matches CryptoJS.mode.CBC on the client).
    decrypted := make([]byte, len(encrypted))
    cipher.NewCBCDecrypter(block, ivBytes).CryptBlocks(decrypted, encrypted)

    // Step 7: strip PKCS#7 padding (CryptoJS.pad.Pkcs7).
    decrypted, err = pKCS7Unpad(decrypted)
    if err != nil {
        return nil, fmt.Errorf("去除填充失败: %w", err)
    }

    return decrypted, nil
}

训练(Training — Python/XGBoost 模型训练脚本)

import pandas as pd
import numpy as np
import warnings
import matplotlib.pyplot as plt
import optuna
import os
import pickle
import json  # 新增JSON库
import platform
from typing import Dict, List, Optional, Tuple
import time
import re

# 替换为XGBoost相关库
import xgboost as xgb
from xgboost.callback import TrainingCallback  # 导入回调基类
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, roc_auc_score
from imblearn.over_sampling import SMOTE

# 忽略警告
warnings.filterwarnings('ignore')


# ====================== Font fix (core) ======================
def set_matplotlib_chinese_font():
    """Select a CJK-capable sans-serif font list for the current OS so Chinese
    labels render correctly (works around SimHei being missing on non-Windows
    systems)."""
    system = platform.system()
    font_table = {
        'Darwin': ['Arial Unicode MS', 'PingFang SC'],       # macOS
        'Windows': ['Microsoft YaHei', 'SimHei'],            # Windows
        'Linux': ['DejaVu Sans', 'WenQuanYi Micro Hei'],     # Linux
    }
    try:
        if system in font_table:
            plt.rcParams['font.sans-serif'] = font_table[system]
        plt.rcParams['axes.unicode_minus'] = False  # render the minus sign properly
        print(f"✅ 字体设置完成(系统:{system}),当前字体:{plt.rcParams['font.sans-serif']}")
    except Exception as e:
        print(f"⚠️ 字体设置失败:{e},使用默认字体")


# Apply the font configuration once at import time.
set_matplotlib_chinese_font()

# Interactive mode so figures refresh live while training runs.
plt.ion()

# ====================== Per-verify-type configuration ======================
# Differentiated hyper-parameters for each verify_type.
VERIFY_TYPE_CONFIGS = {
    'swipe': {
        # Swipe verification: emphasises time-series behaviour features.
        'model_params': {
            'max_depth': 5,
            'learning_rate': 0.015,
            'subsample': 0.75,
            'colsample_bytree': 0.75,
            'reg_alpha': 0.2,
            'reg_lambda': 0.2,
            'min_child_weight': 4,
            'gamma': 0.01,
            'n_estimators': 250
        },
        'core_features': ['load_to_first_delay', 'total_duration', 'avg_swipe_interval','ip_verify_freq','device_ip_bind_count'],
        'smote_strategy': 0.7,  # SMOTE oversampling ratio
        'extreme_quantile': (0.02, 0.98)  # quantiles used to clip extreme values
    },
    'click': {
        # Click verification: emphasises click-behaviour features.
        'model_params': {
            'max_depth': 4,
            'learning_rate': 0.02,
            'subsample': 0.7,
            'colsample_bytree': 0.7,
            'reg_alpha': 0.3,
            'reg_lambda': 0.3,
            'min_child_weight': 6,
            'gamma': 0.03,
            'n_estimators': 200
        },
        'core_features': ['load_to_first_delay', 'total_duration', 'click_interval_cv','ip_verify_freq','device_ip_bind_count'],
        'smote_strategy': 0.8,  # SMOTE oversampling ratio
        'extreme_quantile': (0.01, 0.99)  # quantiles used to clip extreme values
    },
    'default': {
        # Fallback configuration for any other verify_type.
        'model_params': {
            'max_depth': 4,
            'learning_rate': 0.02,
            'subsample': 0.7,
            'colsample_bytree': 0.7,
            'reg_alpha': 0.3,
            'reg_lambda': 0.3,
            'min_child_weight': 5,
            'gamma': 0.02,
            'n_estimators': 200
        },
        'core_features': ['load_to_first_delay', 'total_duration', 'ip_verify_freq','device_ip_bind_count'],
        'smote_strategy': 0.8,
        'extreme_quantile': (0.01, 0.99)
    }
}

# Columns dropped before training (string IDs, IPs, timestamps, PII, ...).
USELESS_COLS = [
    'id', 'sample_id', 'session_id', 'ip', 'device_fingerprint',
    'temp_score', 'created_at', 'updated_at', 'user_id', 'device_id',
    'phone_number', 'email', 'cookie_id', 'session_token'
]

# Global store: per-verify-type training curves and live-plot handles.
train_metrics = {
    'swipe': {
        'train_loss': [], 'val_loss': [], 'train_auc': [], 'val_auc': [],
        'fig': None, 'ax1': None, 'ax2': None, 'ax3': None, 'ax4': None,
        'lines': {}, 'core_feature_importance': {}, 'importance_iterations': []
    },
    'click': {
        'train_loss': [], 'val_loss': [], 'train_auc': [], 'val_auc': [],
        'fig': None, 'ax1': None, 'ax2': None, 'ax3': None, 'ax4': None,
        'lines': {}, 'core_feature_importance': {}, 'importance_iterations': []
    }
}

# Global store: per-verify-type preprocessing parameters.
preprocessing_params = {
    'swipe': {},
    'click': {},
    'default': {}
}

# Pre-create an importance-history list for every core feature of each type.
for vt in ['swipe', 'click']:
    config = VERIFY_TYPE_CONFIGS.get(vt, VERIFY_TYPE_CONFIGS['default'])
    core_features = config['core_features']
    for feat in core_features:
        train_metrics[vt]['core_feature_importance'][feat] = []


# ====================== XGBoost model -> JSON export helper ======================
def convert_xgb_model_to_json(model: xgb.Booster, output_path: str):
    """Export an XGBoost Booster as a pretty-printed JSON file (so the Go
    service can parse it).

    Args:
        model: trained XGBoost Booster.
        output_path: destination path for the JSON file.

    Returns:
        True on success, False when export/parsing/writing fails.
    """
    try:
        # save_raw("json") yields XGBoost's native JSON representation;
        # round-trip through json.loads to guarantee well-formed JSON.
        raw_json = model.save_raw("json")
        parsed = json.loads(raw_json)

        # Persist with indentation and unescaped non-ASCII characters.
        with open(output_path, 'w', encoding='utf-8') as fh:
            json.dump(parsed, fh, indent=4, ensure_ascii=False)

        print(f"✅ XGBoost模型已转换为JSON格式:{output_path}")
        return True
    except Exception as e:
        print(f"❌ XGBoost模型转JSON失败:{e}")
        return False


# ====================== 工具函数 ======================
def is_ip_address(s):
    """Return True when *s* is a dotted-quad IPv4 address.

    Fix: the previous regex accepted any four 1-3 digit groups (so
    '999.999.999.999' counted as an IP); each octet is now range-checked
    against 0-255.

    Args:
        s: candidate value (non-strings always return False).
    """
    if not isinstance(s, str):
        return False
    if re.match(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$', s) is None:
        return False
    # Regex guarantees exactly four all-digit groups; verify octet range.
    return all(int(part) <= 255 for part in s.split('.'))


def is_numeric_column(series):
    """Return True when every non-null value of *series* can safely be coerced
    to a number; columns containing IP-address strings are rejected outright."""
    values = series.dropna()
    if len(values) == 0:
        return False

    # IP-looking strings mark the column as categorical, never numeric.
    contains_ip = values.apply(
        lambda v: is_ip_address(v) if isinstance(v, str) else False
    )
    if contains_ip.any():
        return False

    # Probe the conversion: any failure means the column is not numeric.
    try:
        pd.to_numeric(values, errors='raise')
    except (ValueError, TypeError):
        return False
    return True


def get_verify_type_config(verify_type: str) -> dict:
    """Return the tuning config for *verify_type*, falling back to 'default'."""
    fallback = VERIFY_TYPE_CONFIGS['default']
    return VERIFY_TYPE_CONFIGS.get(verify_type, fallback)


def init_realtime_plot(verify_type: str):
    """Lazily build the live-monitoring figure (4 subplots) for one
    verify_type and cache the figure/axes/line handles in train_metrics.
    No-op when the figure already exists."""
    if train_metrics[verify_type]['fig'] is None:
        # Core features of this verify type drive subplot 3's line set.
        config = get_verify_type_config(verify_type)
        core_features = config['core_features']

        # Overall layout.
        fig = plt.figure(figsize=(15, 10))
        fig.suptitle(f'模型训练监控 - Verify Type: {verify_type.upper()}', fontsize=16)

        # Subplot 1: LogLoss curves.
        ax1 = plt.subplot(2, 2, 1)
        ax1.set_xlabel('迭代次数')
        ax1.set_ylabel('LogLoss')
        ax1.set_title('训练/验证集LogLoss(越小越好)')
        ax1.grid(True, alpha=0.3)
        line1, = ax1.plot([], [], label='训练集', color='blue', linewidth=1.5)
        line2, = ax1.plot([], [], label='验证集', color='red', linewidth=1.5)
        ax1.legend()

        # Subplot 2: AUC curves.
        ax2 = plt.subplot(2, 2, 2)
        ax2.set_xlabel('迭代次数')
        ax2.set_ylabel('AUC')
        ax2.set_title('训练/验证集AUC(越大越好)')
        ax2.grid(True, alpha=0.3)
        line3, = ax2.plot([], [], label='训练集', color='blue', linewidth=1.5)
        line4, = ax2.plot([], [], label='验证集', color='red', linewidth=1.5)
        ax2.legend()

        # Subplot 3: core-feature importance trend.
        ax3 = plt.subplot(2, 2, 3)
        ax3.set_xlabel('迭代次数')
        ax3.set_ylabel('特征重要性(Gain)')
        ax3.set_title(f'核心特征重要性趋势({"/".join(core_features)})')
        ax3.grid(True, alpha=0.3)

        # One line per core feature (colors cycle when features > colors).
        lines = {}
        colors = ['green', 'orange', 'purple', 'brown']
        for i, feat in enumerate(core_features):
            line, = ax3.plot([], [], label=feat, color=colors[i % len(colors)], linewidth=1.5)
            lines[feat] = line
        ax3.legend()

        # Subplot 4: feature distribution comparison (hidden until used).
        ax4 = plt.subplot(2, 2, 4)
        ax4.set_title('核心特征分布(正常vs异常)')
        ax4.set_visible(False)

        # Cache every handle in the global store for incremental updates.
        train_metrics[verify_type]['fig'] = fig
        train_metrics[verify_type]['ax1'] = ax1
        train_metrics[verify_type]['ax2'] = ax2
        train_metrics[verify_type]['ax3'] = ax3
        train_metrics[verify_type]['ax4'] = ax4
        train_metrics[verify_type]['lines']['train_loss'] = line1
        train_metrics[verify_type]['lines']['val_loss'] = line2
        train_metrics[verify_type]['lines']['train_auc'] = line3
        train_metrics[verify_type]['lines']['val_auc'] = line4
        for feat, line in lines.items():
            train_metrics[verify_type]['lines'][feat] = line


def update_realtime_plot(verify_type: str, iteration, train_loss, val_loss, train_auc, val_auc, model=None):
    """Append the latest metric values for *verify_type* and redraw its live
    figure; every 10th iteration also refreshes the per-feature importance
    lines (requires *model*). Assumes init_realtime_plot ran first."""
    metrics = train_metrics[verify_type]
    config = get_verify_type_config(verify_type)
    core_features = config['core_features']

    # Append base metrics.
    metrics['train_loss'].append(train_loss)
    metrics['val_loss'].append(val_loss)
    metrics['train_auc'].append(train_auc)
    metrics['val_auc'].append(val_auc)

    # Refresh the base curves (x axis = 1..len).
    x_data = list(range(1, len(metrics['train_loss']) + 1))

    if len(x_data) == len(metrics['train_loss']):
        metrics['lines']['train_loss'].set_data(x_data, metrics['train_loss'])
    if len(x_data) == len(metrics['val_loss']):
        metrics['lines']['val_loss'].set_data(x_data, metrics['val_loss'])
    if len(x_data) == len(metrics['train_auc']):
        metrics['lines']['train_auc'].set_data(x_data, metrics['train_auc'])
    if len(x_data) == len(metrics['val_auc']):
        metrics['lines']['val_auc'].set_data(x_data, metrics['val_auc'])

    # Refresh core-feature importance every 10 iterations (cheaper than per-step).
    if model and iteration % 10 == 0:
        try:
            importance_dict = model.get_score(importance_type='gain')
            metrics['importance_iterations'].append(iteration)

            # Features absent from the model's importance map count as 0.
            for feat in core_features:
                imp_value = importance_dict.get(feat, 0)
                metrics['core_feature_importance'][feat].append(imp_value)

            # Redraw one line per core feature.
            imp_iter = metrics['importance_iterations']
            for feat in core_features:
                imp_vals = metrics['core_feature_importance'][feat]
                if len(imp_iter) == len(imp_vals) and len(imp_iter) > 0:
                    metrics['lines'][feat].set_data(imp_iter, imp_vals)

            metrics['ax3'].relim()
            metrics['ax3'].autoscale_view()

        except Exception as e:
            print(f"\n⚠️ {verify_type} 特征重要性绘图更新失败:{e}")

    # Rescale the metric axes to the new data.
    metrics['ax1'].relim()
    metrics['ax1'].autoscale_view()
    metrics['ax2'].relim()
    metrics['ax2'].autoscale_view()

    # Flush the canvas; plotting errors must not break training.
    try:
        metrics['fig'].canvas.draw()
        metrics['fig'].canvas.flush_events()
    except Exception as e:
        print(f"\n⚠️ {verify_type} 绘图刷新失败:{e}")

    time.sleep(0.01)


# ====================== 差异化回调类 ======================
class XGBoostRealtimePlotCallback(TrainingCallback):
    """XGBoost real-time plotting callback, differentiated per verify_type.

    After every boosting round it pulls the latest train/eval logloss and
    AUC out of ``evals_log`` and forwards them to ``update_realtime_plot``.
    """

    def __init__(self, verify_type: str):
        super().__init__()
        self.iteration = 0              # last boosting round seen
        self.verify_type = verify_type  # e.g. 'swipe' or 'click'

    def after_iteration(self, model, epoch, evals_log):
        """Run after each boosting round; returning False continues training."""
        self.iteration = epoch

        try:
            # Extract the most recent value of each metric, when present.
            train_loss = val_loss = train_auc = val_auc = None

            if 'train' in evals_log and 'logloss' in evals_log['train']:
                train_loss = evals_log['train']['logloss'][-1]
            if 'eval' in evals_log and 'logloss' in evals_log['eval']:
                val_loss = evals_log['eval']['logloss'][-1]
            if 'train' in evals_log and 'auc' in evals_log['train']:
                train_auc = evals_log['train']['auc'][-1]
            if 'eval' in evals_log and 'auc' in evals_log['eval']:
                val_auc = evals_log['eval']['auc'][-1]

            # BUG FIX: the original used all([...]), which silently skipped
            # the plot update whenever any metric was exactly 0.0 (falsy).
            # Test explicitly for None instead.
            if all(m is not None for m in (train_loss, val_loss, train_auc, val_auc)):
                update_realtime_plot(self.verify_type, self.iteration,
                                     train_loss, val_loss, train_auc, val_auc, model)

        except Exception as e:
            print(f"\n⚠️ {self.verify_type} 回调函数执行失败(迭代{epoch}):{e}")

        return False


# ====================== 数据处理函数(修复版本) ======================
def load_and_validate_data(csv_file: str) -> pd.DataFrame:
    """Load the captcha behaviour dataset from CSV and validate its shape.

    Verifies required columns exist, keeps only 'swipe'/'click' rows, and
    strips columns that must never reach the model (IDs, IPs, timestamps).

    Raises:
        Exception: wrapping any underlying failure, with the cause chained.
    """
    try:
        df = pd.read_csv(csv_file, encoding='utf-8')
        print(f"✅ 数据加载成功,形状: {df.shape}")
        print(f"✅ verify_type分布:\n{df['verify_type'].value_counts()}")

        # Fail fast if any required column is missing.
        required_cols = [
            'label', 'verify_type', 'load_to_first_delay', 'total_duration'
        ]
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"❌ 缺少关键列:{missing_cols}")

        # Keep only the verification types the pipeline supports.
        valid_verify_types = ['swipe', 'click']
        df = df[df['verify_type'].isin(valid_verify_types)]
        print(f"✅ 过滤后有效数据量:{len(df)}(仅保留swipe/click)")

        # Drop string IDs / IP / timestamp columns up front.
        df = filter_useless_columns(df)

        return df
    except Exception as e:
        # FIX: chain the original exception ("from e") so the real
        # traceback is preserved instead of being discarded.
        raise Exception(f"❌ 加载数据失败:{str(e)}") from e


def filter_useless_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Drop columns with no modelling value (string IDs, IPs, timestamps)."""
    result = df.copy()

    # Pass 1: statically configured junk columns present in this frame.
    known_junk = [c for c in USELESS_COLS if c in result.columns]
    if known_junk:
        result = result.drop(columns=known_junk)
        print(f"\n🗑️ 已删除无用列:{known_junk}")

    # Pass 2: heuristic — any remaining column whose name mentions "ip"
    # (other than the whitelisted feature columns) and whose sampled values
    # look like IP addresses is dropped as well.
    keep_despite_name = ('ip_type', 'ip_verify_freq', 'device_ip_bind_count')
    ip_like = []
    for column in result.columns:
        if 'ip' not in column.lower() or column in keep_despite_name:
            continue
        sample = result[column].dropna().head(10)
        looks_like_ip = sample.apply(
            lambda v: is_ip_address(v) if isinstance(v, str) else False
        )
        if looks_like_ip.any():
            ip_like.append(column)

    if ip_like:
        result = result.drop(columns=ip_like)
        print(f"🗑️ 已删除包含IP地址的列:{ip_like}")

    return result


def filter_extreme_samples(df: pd.DataFrame, verify_type: str, cols: list) -> pd.DataFrame:
    """Drop rows whose values fall outside type-specific quantile bounds.

    Also seeds ``preprocessing_params[verify_type]`` with per-column stats
    (quantiles, mean, std, median) for serving-time reuse.
    """
    lower_q, upper_q = get_verify_type_config(verify_type)['extreme_quantile']

    result = df.copy()
    remaining = len(result)

    # Fresh parameter store for this verify_type.
    preprocessing_params[verify_type] = {}

    for column in cols:
        if column not in result.columns:
            print(f"⚠️ {verify_type} 列{column}不存在,跳过过滤")
            continue

        # Only numeric columns can be quantile-filtered.
        if not is_numeric_column(result[column]):
            print(f"⚠️ {verify_type} 列{column}不是数值类型,跳过过滤")
            continue

        low = result[column].quantile(lower_q)
        high = result[column].quantile(upper_q)

        # Persist column statistics for the serving side.
        preprocessing_params[verify_type][column] = {
            'q_low': low, 'q_high': high,
            'mean': result[column].mean(), 'std': result[column].std(),
            'median': result[column].median()
        }

        # Keep only rows within [low, high].
        result = result[(result[column] >= low) & (result[column] <= high)]

        dropped = remaining - len(result)
        remaining = len(result)

        print(f"\n🔧 [{verify_type}] {column} 极端值过滤:")
        print(f"   分位数范围:{low:.2f} ~ {high:.2f}")
        print(f"   过滤样本数:{dropped},剩余样本数:{remaining}")

    return result


def process_labels(df: pd.DataFrame) -> pd.DataFrame:
    """Map raw captcha labels onto a binary target column.

    'normal' maps to 0; 'risk', 'abnormal' and any unrecognised label map
    to 1. The frame is mutated in place and also returned.
    """
    print(f"\n📊 原始标签分布:")
    print(df['label'].value_counts())

    # Unknown labels default to 1 (treated as suspicious).
    mapping = {'normal': 0, 'risk': 1, 'abnormal': 1}
    df['binary_label'] = df['label'].map(lambda value: mapping.get(value, 1))

    print(f"\n📊 转换后的二进制标签分布:")
    print(df['binary_label'].value_counts())

    return df


def select_features(df: pd.DataFrame, verify_type: str) -> Tuple[List[str], List[str], List[str]]:
    """Partition usable columns into (all, categorical, numeric) feature lists."""
    core_features = get_verify_type_config(verify_type)['core_features']

    # Everything except the labels and the type discriminator is a candidate.
    feature_cols = [c for c in df.columns if c not in ('label', 'binary_label', 'verify_type')]

    # Auto-classify each candidate as numeric or categorical.
    numeric_cols, categorical_cols = [], []
    for column in feature_cols:
        bucket = numeric_cols if is_numeric_column(df[column]) else categorical_cols
        bucket.append(column)

    print(f"\n🔧 [{verify_type}] 自动识别特征类型:")
    print(f"   分类特征: {categorical_cols}")
    # Star-mark the configured core features for readability.
    marked = [f"⭐{c}" if c in core_features else c for c in numeric_cols]
    print(f"   数值特征: {marked}")

    # Warn when an expected core feature is absent from the numeric set.
    missing_core = [c for c in core_features if c not in numeric_cols]
    if missing_core:
        print(f"⚠️ [{verify_type}] 核心特征缺失:{missing_core}")

    return feature_cols, categorical_cols, numeric_cols


def fix_numeric_outliers(df: pd.DataFrame, verify_type: str, numeric_cols: list) -> pd.DataFrame:
    """Repair invalid numeric values, with different rules per feature group.

    Core features only have negatives zeroed; all other numeric features
    are clipped to the 1.5*IQR fence.
    """
    core_features = get_verify_type_config(verify_type)['core_features']
    repaired = df.copy()

    # Core features: coerce to numeric, then zero out negative values.
    for column in core_features:
        if column not in numeric_cols or column not in repaired.columns:
            continue
        repaired[column] = pd.to_numeric(repaired[column], errors='coerce')

        negatives = repaired[column] < 0
        neg_count = negatives.sum()
        if neg_count > 0:
            repaired.loc[negatives, column] = 0
            print(f"\n🔧 [{verify_type}] 修复{column}负数:{neg_count} 个值")

    # Remaining numeric features: clip outliers to the IQR fences.
    others = [c for c in numeric_cols if c not in core_features and c in repaired.columns]
    for column in others:
        repaired[column] = pd.to_numeric(repaired[column], errors='coerce')

        q1 = repaired[column].quantile(0.25)
        q3 = repaired[column].quantile(0.75)
        iqr = q3 - q1
        low_fence = q1 - 1.5 * iqr
        high_fence = q3 + 1.5 * iqr

        outliers = (repaired[column] < low_fence) | (repaired[column] > high_fence)
        n_outliers = outliers.sum()

        if n_outliers > 0:
            # Assign clipped values only to the outlier rows (index-aligned).
            repaired.loc[outliers, column] = repaired[column].clip(low_fence, high_fence)
            print(f"🔧 [{verify_type}] 修复{column}异常值:{n_outliers} 个值")

    return repaired


def handle_missing_values(df: pd.DataFrame, verify_type: str, numeric_cols: list,
                          categorical_cols: list) -> pd.DataFrame:
    """Fill missing values per verify_type and record the fill statistics.

    Numeric medians and categorical modes are written into the global
    ``preprocessing_params[verify_type]`` so the serving side (Go) can
    apply the same imputation at prediction time. Returns a filled copy.
    """
    config = get_verify_type_config(verify_type)
    core_features = config['core_features']

    df_filled = df.copy()

    # Medians of numeric features (exported for Go-side imputation).
    numeric_medians = {}
    # Modes of categorical features (exported for Go-side imputation).
    categorical_modes = {}

    # Fill core features first.
    for col in core_features:
        if col in numeric_cols and col in df_filled.columns:
            # Coerce to numeric; unparsable values become NaN.
            df_filled[col] = pd.to_numeric(df_filled[col], errors='coerce')
            # Median recorded BEFORE filling, so it reflects observed data only.
            numeric_medians[col] = df_filled[col].median()

            missing_count = df_filled[col].isnull().sum()
            if missing_count > 0:
                if 'delay' in col or 'duration' in col:
                    # Timing features: fill with the median.
                    fill_val = df_filled[col].median()
                    df_filled[col] = df_filled[col].fillna(fill_val)
                    print(f"\n🔧 [{verify_type}] 填充{col}缺失值:{missing_count} 个值,中位数={fill_val:.2f}")
                else:
                    # Frequency features default to 0; everything else uses the median.
                    fill_val = 0 if 'freq' in col else df_filled[col].median()
                    df_filled[col] = df_filled[col].fillna(fill_val)
                    print(f"\n🔧 [{verify_type}] 填充{col}缺失值:{missing_count} 个值,默认值={fill_val:.2f}")

    # Fill the remaining (non-core) numeric features with their median.
    for col in [c for c in numeric_cols if c not in core_features and c in df_filled.columns]:
        df_filled[col] = pd.to_numeric(df_filled[col], errors='coerce')
        numeric_medians[col] = df_filled[col].median()
        missing_count = df_filled[col].isnull().sum()
        if missing_count > 0:
            fill_val = df_filled[col].median()
            df_filled[col] = df_filled[col].fillna(fill_val)

    # Fill categorical features with their mode.
    for col in categorical_cols:
        if col in df_filled.columns:
            if df_filled[col].isnull().sum() > 0:
                # NOTE(review): mode()[0] raises IndexError if the column is
                # entirely null — confirm upstream guarantees at least one value.
                mode_val = df_filled[col].mode()[0]
                categorical_modes[col] = mode_val
                df_filled[col] = df_filled[col].fillna(mode_val)
        else:
            # Column absent entirely: record a placeholder for the serving side.
            categorical_modes[col] = 'Unknown'

    # Persist the imputation values into the global preprocessing parameters.
    preprocessing_params[verify_type]['numeric_medians'] = numeric_medians
    preprocessing_params[verify_type]['categorical_modes'] = categorical_modes

    return df_filled


def encode_categorical_features(df: pd.DataFrame, categorical_cols: list) -> Tuple[
    pd.DataFrame, Dict[str, LabelEncoder]]:
    """Label-encode categorical feature columns.

    Missing values are replaced with the literal string 'Unknown' before
    fitting, so the encoder always has a class for absent data.

    Returns:
        The encoded frame (a copy) and a dict of fitted LabelEncoders
        keyed by column name.
    """
    label_encoders = {}
    df_encoded = df.copy()

    for col in categorical_cols:
        if col in df_encoded.columns:
            le = LabelEncoder()
            # astype(str) guarantees an object dtype, so 'Unknown' is always
            # representable. The original's `cat.add_categories('Unknown')`
            # branch was dead code: after astype(str) the dtype can never be
            # 'category', and the expression merely reassigned the column to
            # itself. It has been removed.
            df_encoded[col] = df_encoded[col].fillna('Unknown').astype(str)
            df_encoded[col] = le.fit_transform(df_encoded[col])
            label_encoders[col] = le
            print(f"\n🔤 编码 {col}:原始值={le.classes_[:10]}...(显示前10个)")

    return df_encoded, label_encoders


def balance_samples(X: pd.DataFrame, y: pd.Series, verify_type: str) -> Tuple[pd.DataFrame, pd.Series]:
    """Oversample the minority class with SMOTE when it is under-represented.

    Falls back to the original (X, y) when only one class is present, when
    the positive ratio already meets the target, or when SMOTE fails.
    """
    target_ratio = get_verify_type_config(verify_type)['smote_strategy']

    # SMOTE needs both classes to exist.
    if y.nunique() < 2:
        print(f"\n⚠️ [{verify_type}] 标签只有1类,跳过SMOTE采样")
        return X, y

    pos_ratio = y.sum() / len(y)
    print(f"\n📊 [{verify_type}] 原始样本分布:异常样本占比 {pos_ratio:.2%}")

    if pos_ratio < target_ratio:
        try:
            sampler = SMOTE(random_state=42, sampling_strategy=target_ratio)
            X_res, y_res = sampler.fit_resample(X, y)
            new_pos_ratio = y_res.sum() / len(y_res)
            print(f"✅ [{verify_type}] 采样后异常样本占比 {new_pos_ratio:.2%},样本量 {len(X_res)}")
            return X_res, y_res
        except Exception as e:
            # Best-effort: report and keep the unbalanced data.
            print(f"❌ [{verify_type}] SMOTE采样失败:{e}")

    return X, y


def create_core_feature_interactions(df: pd.DataFrame, verify_type: str) -> pd.DataFrame:
    """Create verify_type-specific interaction features between core columns.

    All ratio features add 1e-6 to the denominator to avoid division by
    zero. Returns a copy of ``df`` with the new columns appended.

    Note: the original fetched the verify_type config here but never used
    it; the unused lookup has been removed.
    """
    df_interact = df.copy()

    if verify_type == 'swipe':
        # Swipe: ratio of initial delay to the average swipe interval.
        if all(feat in df_interact.columns for feat in ['load_to_first_delay', 'avg_swipe_interval']):
            df_interact['load_swipe_ratio'] = df_interact['load_to_first_delay'] / (
                    df_interact['avg_swipe_interval'] + 1e-6)

    elif verify_type == 'click':
        # Click: total duration weighted by click-interval variability.
        if all(feat in df_interact.columns for feat in ['total_duration', 'click_interval_cv']):
            df_interact['duration_click_cv'] = df_interact['total_duration'] * df_interact['click_interval_cv']

    # Shared: share of total duration spent before the first action.
    if all(feat in df_interact.columns for feat in ['load_to_first_delay', 'total_duration']):
        df_interact['load_to_total_ratio'] = df_interact['load_to_first_delay'] / (df_interact['total_duration'] + 1e-6)

    print(f"\n🔧 [{verify_type}] 创建核心特征交互项完成")
    return df_interact


# ====================== 模型训练函数(差异化版本) ======================
def tune_hyperparameters(X_train, y_train, X_val, y_val, verify_type: str):
    """Run a small Optuna search around the type-specific base parameters.

    Each hyperparameter is only perturbed within a narrow band (±1 for
    max_depth, roughly ±20%/±0.05 for the rest) around the configured
    default, so tuning refines rather than replaces the hand-picked
    configuration. Returns the best parameter dict found.
    """
    config = get_verify_type_config(verify_type)
    base_params = config['model_params']

    def objective(trial):
        # Search space: a tight neighbourhood of the base parameters.
        params = {
            'booster': 'gbtree',
            'objective': 'binary:logistic',
            'eval_metric': ['logloss', 'auc'],
            'max_depth': trial.suggest_int('max_depth', base_params['max_depth'] - 1, base_params['max_depth'] + 1),
            'learning_rate': trial.suggest_float('learning_rate', base_params['learning_rate'] * 0.8,
                                                 base_params['learning_rate'] * 1.2, log=True),
            'subsample': trial.suggest_float('subsample', base_params['subsample'] - 0.05,
                                             base_params['subsample'] + 0.05),
            'colsample_bytree': trial.suggest_float('colsample_bytree', base_params['colsample_bytree'] - 0.05,
                                                    base_params['colsample_bytree'] + 0.05),
            'reg_alpha': trial.suggest_float('reg_alpha', base_params['reg_alpha'] * 0.8,
                                             base_params['reg_alpha'] * 1.2),
            'reg_lambda': trial.suggest_float('reg_lambda', base_params['reg_lambda'] * 0.8,
                                              base_params['reg_lambda'] * 1.2),
            'random_state': 42,
            'verbosity': 0
        }

        dtrain = xgb.DMatrix(X_train, label=y_train)
        dval = xgb.DMatrix(X_val, label=y_val)

        # Short run with early stopping; best_score on the eval set is
        # the trial objective (maximised by the study below).
        model = xgb.train(
            params, dtrain, num_boost_round=100,
            evals=[(dval, 'eval')], early_stopping_rounds=15, verbose_eval=False
        )
        return model.best_score

    # Only 5 trials: tuning is intentionally cheap given the narrow space.
    study = optuna.create_study(direction='maximize', study_name=f'{verify_type}_xgb_tune')
    study.optimize(objective, n_trials=5)
    print(f"\n🏆 [{verify_type}] 最优参数:{study.best_params}")
    print(f"🏆 [{verify_type}] 最优验证AUC:{study.best_value:.4f}")
    return study.best_params


def train_xgboost_model(X_train, y_train, X_test, y_test, verify_type: str, use_optuna: bool = True):
    """Train an XGBoost model for one verify_type with live plotting.

    Optionally tunes hyperparameters with Optuna first, trains with a
    real-time plotting callback (retrying without it on failure), and
    finally draws core-feature box plots on the type's 'ax4' axis.

    Returns:
        (booster, evals_result dict, final parameter dict).
    """
    # Set up the real-time training plot for this verify_type.
    init_realtime_plot(verify_type)

    # Type-specific model configuration.
    config = get_verify_type_config(verify_type)
    model_params = config['model_params']

    # Base parameters; type-specific params override/extend these.
    base_params = {
        'booster': 'gbtree',
        'objective': 'binary:logistic',
        'eval_metric': ['logloss', 'auc'],
        'verbosity': 1,
        'random_state': 42,
        'scale_pos_weight': 1,
        'importance_type': 'gain',
        **model_params
    }

    # Optional hyperparameter tuning (requires both classes present).
    if use_optuna and y_train.nunique() >= 2:
        X_train_tune, X_val_tune, y_train_tune, y_val_tune = train_test_split(
            X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
        )
        best_params = tune_hyperparameters(X_train_tune, y_train_tune, X_val_tune, y_val_tune, verify_type)
        base_params = {**base_params, **best_params}

    # Build the DMatrix datasets.
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dtest = xgb.DMatrix(X_test, label=y_test)

    # Record the feature column order for serving-time alignment.
    preprocessing_params[verify_type]['feature_cols'] = X_train.columns.tolist()

    # Per-iteration real-time plotting callback.
    plot_callback = XGBoostRealtimePlotCallback(verify_type)

    # Train the model; if the plotting callback breaks training, retry without it.
    print(f"\n🚀 开始训练 {verify_type} 类型验证码模型...")
    evals_result = {}
    try:
        model = xgb.train(
            base_params,
            dtrain,
            num_boost_round=base_params['n_estimators'],
            evals=[(dtrain, 'train'), (dtest, 'eval')],
            evals_result=evals_result,
            early_stopping_rounds=20,
            verbose_eval=10,
            callbacks=[plot_callback]
        )
    except Exception as e:
        print(f"\n⚠️ [{verify_type}] 训练异常:{e}")
        # Retry without callbacks (plotting is best-effort only).
        model = xgb.train(
            base_params,
            dtrain,
            num_boost_round=base_params['n_estimators'],
            evals=[(dtrain, 'train'), (dtest, 'eval')],
            evals_result=evals_result,
            early_stopping_rounds=20,
            verbose_eval=10
        )

    # Plot core-feature distributions (normal vs abnormal) as box plots.
    try:
        if y_test.nunique() >= 2:
            ax4 = train_metrics[verify_type]['ax4']
            ax4.clear()
            ax4.set_title(f'{verify_type} 核心特征分布(正常vs异常)')

            X_test_df = pd.DataFrame(X_test, columns=X_train.columns)
            normal_mask = y_test == 0
            abnormal_mask = y_test == 1

            core_features = config['core_features']
            positions = list(range(1, len(core_features) + 1))
            has_data = False

            for i, feat in enumerate(core_features):
                if feat in X_test_df.columns:
                    normal_data = X_test_df.loc[normal_mask, feat].values
                    abnormal_data = X_test_df.loc[abnormal_mask, feat].values

                    if len(normal_data) > 0 and len(abnormal_data) > 0:
                        has_data = True
                        # Side-by-side boxes: green=normal (left), red=abnormal (right);
                        # legend labels only on the first feature to avoid duplicates.
                        ax4.boxplot(normal_data, positions=[positions[i] - 0.15], widths=0.2,
                                    patch_artist=True, boxprops=dict(facecolor='green', alpha=0.5),
                                    label='正常' if i == 0 else "")
                        ax4.boxplot(abnormal_data, positions=[positions[i] + 0.15], widths=0.2,
                                    patch_artist=True, boxprops=dict(facecolor='red', alpha=0.5),
                                    label='异常' if i == 0 else "")

            if has_data:
                ax4.set_xticks(positions)
                ax4.set_xticklabels(core_features)
                ax4.legend()
                ax4.grid(True, alpha=0.3)
            ax4.set_visible(True)
    except Exception as e:
        print(f"\n⚠️ [{verify_type}] 特征分布绘图失败:{e}")

    return model, evals_result, base_params


def evaluate_model(model, X_test, y_test, feature_cols, verify_type: str):
    """Evaluate a trained booster and report feature-importance statistics.

    Returns a dict with accuracy, AUC (-1 when undefined), an importance
    DataFrame, the core-feature contribution percentage, the best
    iteration and a plain-dict copy of the importances.
    """
    core_features = get_verify_type_config(verify_type)['core_features']

    # Probabilities, thresholded at 0.5 for hard labels.
    probabilities = model.predict(xgb.DMatrix(X_test))
    predictions = (probabilities > 0.5).astype(int)

    accuracy = accuracy_score(y_test, predictions)
    # AUC is undefined with a single class; report -1 in that case.
    auc = roc_auc_score(y_test, probabilities) if y_test.nunique() >= 2 else -1

    print(f"\n📈 [{verify_type}] 模型评估结果:")
    print(f"🏆 最佳迭代次数:{model.best_iteration}")
    print(f"🎯 准确率:{accuracy:.4f}")
    print(f"📊 AUC:{auc:.4f}" if auc != -1 else "📊 AUC:无意义")

    # Gain-based importance for every feature (0 when unused by any tree).
    gain_scores = model.get_score(importance_type='gain')
    all_importance = {feat: gain_scores.get(feat, 0) for feat in feature_cols}
    importance_df = pd.DataFrame({
        'feature': list(all_importance.keys()),
        'importance': list(all_importance.values())
    }).sort_values('importance', ascending=False)

    # Star-mark the configured core features.
    importance_df['is_core'] = importance_df['feature'].apply(lambda f: '⭐' if f in core_features else '')

    # Share of total importance contributed by the core features.
    total_importance = importance_df['importance'].sum()
    core_total = importance_df[importance_df['feature'].isin(core_features)]['importance'].sum()
    core_contribution = (core_total / total_importance * 100) if total_importance > 0 else 0

    print(f"\n🎯 [{verify_type}] 核心特征总贡献:{core_contribution:.1f}%")
    print(f"\n✨ [{verify_type}] 特征重要性(前10):")
    print(importance_df.head(10)[['is_core', 'feature', 'importance']])

    return {
        'accuracy': accuracy, 'auc': auc,
        'feature_importance': importance_df,
        'core_feature_contribution': core_contribution,
        'best_iteration': model.best_iteration,
        'feature_importance_dict': all_importance  # plain dict for JSON export
    }


def save_results(models: dict, label_encoders: dict, feature_cols: dict,
                 eval_results: dict, params: dict, save_dir='./captcha_results'):
    """Persist every trained model plus its configuration.

    Per verify_type this writes: the booster binary, a native-JSON model
    dump, pickled preprocessing parameters, a consolidated JSON config
    (metrics, importances, encoders, imputation values) and a CSV of
    evaluation metrics. It also writes a global JSON config and the
    pickled label encoders.
    """
    # Ensure output directories exist.
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs('./captcha_models', exist_ok=True)

    # Save artifacts for each trained verify_type.
    for verify_type in ['swipe', 'click']:
        if verify_type in models:
            # Booster binary.
            model_bin_path = f'./captcha_models/captcha_xgb_model_{verify_type}.bin'
            models[verify_type].save_model(model_bin_path)
            print(f"\n💾 {verify_type} 模型已保存到:{model_bin_path}")

            # Native JSON dump of the model (for non-Python consumers, e.g. Go).
            model_json_path = f'./captcha_models/captcha_xgb_model_{verify_type}.json'
            convert_xgb_model_to_json(models[verify_type], model_json_path)

            # Preprocessing parameters (pickle format).
            preprocess_path = f'./captcha_models/preprocessing_params_{verify_type}.pkl'
            with open(preprocess_path, 'wb') as f:
                pickle.dump(preprocessing_params[verify_type], f)

            # ---------- Consolidated JSON configuration file ----------
            # 1. Build the JSON-serializable payload (cast numpy scalars).
            json_data = {
                'verify_type': verify_type,
                'model_config': params.get(verify_type, {}),
                'preprocessing_params': preprocessing_params.get(verify_type, {}),
                'evaluation_metrics': {
                    'accuracy': float(eval_results[verify_type]['accuracy']),
                    'auc': float(eval_results[verify_type]['auc']),
                    'core_feature_contribution': float(eval_results[verify_type]['core_feature_contribution']),
                    'best_iteration': int(eval_results[verify_type]['best_iteration'])
                },
                'feature_importance': eval_results[verify_type]['feature_importance_dict'],
                'core_features': get_verify_type_config(verify_type)['core_features'],
                'feature_columns': feature_cols.get(verify_type, []),
                'training_timestamp': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                'numeric_medians': preprocessing_params[verify_type].get('numeric_medians', {}),
                'categorical_modes': preprocessing_params[verify_type].get('categorical_modes', {})
            }

            # 2. LabelEncoders: store only the class lists and index mappings.
            le_json = {}
            if verify_type in label_encoders:
                for col, le in label_encoders[verify_type].items():
                    le_json[col] = {
                        'classes': le.classes_.tolist(),
                        'mapping': {str(cls): int(idx) for idx, cls in enumerate(le.classes_)},
                        'num_classes': len(le.classes_)
                    }
            json_data['label_encoders'] = le_json

            # 3. Write the JSON config file.
            json_config_path = f'./captcha_models/captcha_model_{verify_type}.json'
            with open(json_config_path, 'w', encoding='utf-8') as f:
                json.dump(json_data, f, ensure_ascii=False, indent=4)
            print(f"💾 {verify_type} 模型JSON配置文件已保存到:{json_config_path}")

            # Evaluation metrics (CSV format).
            eval_df = pd.DataFrame({
                'metric': ['accuracy', 'auc', 'core_feature_contribution', 'best_iteration'],
                'value': [
                    eval_results[verify_type]['accuracy'],
                    eval_results[verify_type]['auc'],
                    eval_results[verify_type]['core_feature_contribution'],
                    eval_results[verify_type]['best_iteration']
                ]
            })
            eval_df.to_csv(f'{save_dir}/evaluation_results_{verify_type}.csv', index=False)

    # Global configuration JSON file (shared across verify_types).
    global_config = {
        'verify_type_configs': VERIFY_TYPE_CONFIGS,
        'useless_columns': USELESS_COLS,
        'training_timestamp': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    }
    global_config_path = './captcha_models/global_config.json'
    with open(global_config_path, 'w', encoding='utf-8') as f:
        json.dump(global_config, f, ensure_ascii=False, indent=4)
    print(f"💾 全局配置JSON文件已保存到:{global_config_path}")

    # Label encoders (pickle format, legacy consumers).
    le_path = './captcha_models/label_encoders.pkl'
    with open(le_path, 'wb') as f:
        pickle.dump(label_encoders, f)

    print(f"\n✅ 所有模型和配置已保存完成!")


# ====================== 统一预测接口(兼容JSON配置) ======================
def predict_captcha(input_data: dict):
    """
    Unified prediction entry point: picks the model matching the input's
    verify_type, replays the training-time preprocessing, and returns the
    abnormality probability plus a hard label.
    """
    # Resolve the verification type; anything unsupported falls back to 'default'.
    verify_type = input_data.get('verify_type', 'default').lower()
    if verify_type not in ['swipe', 'click']:
        verify_type = 'default'

    # Artifact paths for this verify_type.
    model_path = f'./captcha_models/captcha_xgb_model_{verify_type}.bin'
    preprocess_path = f'./captcha_models/preprocessing_params_{verify_type}.pkl'
    json_config_path = f'./captcha_models/captcha_model_{verify_type}.json'

    if not os.path.exists(model_path):
        raise Exception(f"❌ 未找到 {verify_type} 类型的模型文件")

    # Load the trained booster.
    model = xgb.Booster()
    model.load_model(model_path)

    # Prefer the JSON config when available.
    if os.path.exists(json_config_path):
        with open(json_config_path, 'r', encoding='utf-8') as f:
            json_config = json.load(f)
        preprocess_params = json_config.get('preprocessing_params', {})
        feature_cols = json_config.get('feature_columns', [])
        label_encoders_json = json_config.get('label_encoders', {})
    else:
        # Legacy pickle fallback.
        # NOTE(review): this path loads label_encoders but leaves
        # label_encoders_json empty, so categorical columns are never
        # encoded below and DMatrix construction may fail on string
        # columns — confirm whether the pickle path is still exercised.
        with open(preprocess_path, 'rb') as f:
            preprocess_params = pickle.load(f)
        feature_cols = preprocess_params.get('feature_cols', [])
        le_path = './captcha_models/label_encoders.pkl'
        with open(le_path, 'rb') as f:
            label_encoders = pickle.load(f)
        label_encoders_json = {}

    # Single-row frame mirroring the training-time preprocessing.
    df = pd.DataFrame([input_data])

    # Drop ID/IP/timestamp columns, same as training.
    df = filter_useless_columns(df)

    # Zero out negative numeric values (the frame has exactly one row,
    # so assigning 0 to the column affects only that row's value).
    for col in df.columns:
        if df[col].dtype in ['int64', 'float64'] and df[col].iloc[0] < 0:
            df[col] = 0

    # Encode categorical features using the JSON-stored encoder mappings;
    # unseen categories map to 'Unknown' (or 0 if 'Unknown' is absent).
    for col in ['os', 'device_type', 'ip_type']:
        if col in df.columns and col in label_encoders_json:
            le_mapping = label_encoders_json[col]['mapping']
            df[col] = df[col].fillna('Unknown').astype(str)
            # Map unknown categories to the 'Unknown' class.
            df[col] = df[col].apply(lambda x: le_mapping.get(x, le_mapping.get('Unknown', 0)))

    # Align to the training feature set: add missing columns as 0 and
    # enforce the training-time column order.
    for col in feature_cols:
        if col not in df.columns:
            df[col] = 0
    df = df[feature_cols]

    # Predict and threshold at 0.5.
    dmatrix = xgb.DMatrix(df)
    pred_prob = model.predict(dmatrix)[0]
    pred_label = 1 if pred_prob > 0.5 else 0

    return {
        'verify_type': verify_type,
        'abnormal_probability': float(pred_prob),
        'predicted_label': int(pred_label),
        'label_description': '异常/疑似' if pred_label == 1 else '正常',
        'prediction_timestamp': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    }


# ====================== 主函数 ======================
def main():
    """Entry point: train per-verify_type models, save them, and smoke-test prediction."""
    # Configuration.
    CSV_FILE_PATH = './captcha_verify_feature.csv'
    USE_OPTUNA = False

    try:
        # 1. Load and validate the dataset.
        df = load_and_validate_data(CSV_FILE_PATH)

        # 2. Split the data by verify_type.
        swipe_df = df[df['verify_type'] == 'swipe'].copy()
        click_df = df[df['verify_type'] == 'click'].copy()

        # Result accumulators keyed by verify_type.
        models = {}
        eval_results = {}
        all_params = {}
        label_encoders = {}
        feature_cols_dict = {}

        # 3. Train the swipe model.
        if len(swipe_df) > 0:
            print("\n" + "=" * 60)
            print("开始训练 SWIPE (滑动) 验证模型")
            print("=" * 60)

            # Data preprocessing.
            swipe_df = process_labels(swipe_df)
            swipe_df = filter_extreme_samples(swipe_df, 'swipe', ['load_to_first_delay', 'total_duration'])

            # Feature processing.
            feature_cols, categorical_cols, numeric_cols = select_features(swipe_df, 'swipe')
            swipe_df = fix_numeric_outliers(swipe_df, 'swipe', numeric_cols)
            swipe_df = handle_missing_values(swipe_df, 'swipe', numeric_cols, categorical_cols)
            swipe_df = create_core_feature_interactions(swipe_df, 'swipe')

            # Refresh the feature list (interaction columns were added above).
            feature_cols = [col for col in swipe_df.columns if col not in ['label', 'binary_label', 'verify_type']]
            swipe_df_encoded, le = encode_categorical_features(swipe_df, categorical_cols)
            label_encoders['swipe'] = le

            # Split features and labels.
            X = swipe_df_encoded[feature_cols]
            y = swipe_df_encoded['binary_label']

            # Balance the classes (SMOTE when needed).
            X_balanced, y_balanced = balance_samples(X, y, 'swipe')

            # Train/test split (stratified only when both classes exist).
            X_train, X_test, y_train, y_test = train_test_split(
                X_balanced, y_balanced, test_size=0.2, random_state=42,
                stratify=y_balanced if y_balanced.nunique() >= 2 else None
            )

            # Train the model.
            model, evals_result, params = train_xgboost_model(
                X_train, y_train, X_test, y_test, 'swipe', USE_OPTUNA
            )

            # Evaluate the model.
            eval_res = evaluate_model(model, X_test, y_test, feature_cols, 'swipe')

            # Collect results.
            models['swipe'] = model
            eval_results['swipe'] = eval_res
            all_params['swipe'] = params
            feature_cols_dict['swipe'] = feature_cols

        # 4. Train the click model (same pipeline as swipe).
        if len(click_df) > 0:
            print("\n" + "=" * 60)
            print("开始训练 CLICK (点击) 验证模型")
            print("=" * 60)

            # Data preprocessing.
            click_df = process_labels(click_df)
            click_df = filter_extreme_samples(click_df, 'click', ['load_to_first_delay', 'total_duration'])

            # Feature processing.
            feature_cols, categorical_cols, numeric_cols = select_features(click_df, 'click')
            click_df = fix_numeric_outliers(click_df, 'click', numeric_cols)
            click_df = handle_missing_values(click_df, 'click', numeric_cols, categorical_cols)
            click_df = create_core_feature_interactions(click_df, 'click')

            # Refresh the feature list (interaction columns were added above).
            feature_cols = [col for col in click_df.columns if col not in ['label', 'binary_label', 'verify_type']]
            click_df_encoded, le = encode_categorical_features(click_df, categorical_cols)
            label_encoders['click'] = le

            # Split features and labels.
            X = click_df_encoded[feature_cols]
            y = click_df_encoded['binary_label']

            # Balance the classes (SMOTE when needed).
            X_balanced, y_balanced = balance_samples(X, y, 'click')

            # Train/test split (stratified only when both classes exist).
            X_train, X_test, y_train, y_test = train_test_split(
                X_balanced, y_balanced, test_size=0.2, random_state=42,
                stratify=y_balanced if y_balanced.nunique() >= 2 else None
            )

            # Train the model.
            model, evals_result, params = train_xgboost_model(
                X_train, y_train, X_test, y_test, 'click', USE_OPTUNA
            )

            # Evaluate the model.
            eval_res = evaluate_model(model, X_test, y_test, feature_cols, 'click')

            # Collect results.
            models['click'] = model
            eval_results['click'] = eval_res
            all_params['click'] = params
            feature_cols_dict['click'] = feature_cols

        # 5. Persist all results (binary, pickle and JSON formats).
        save_results(models, label_encoders, feature_cols_dict, eval_results, all_params)

        # 6. Smoke-test the prediction path.
        print("\n" + "=" * 60)
        print("测试差异化预测功能")
        print("=" * 60)

        # Sample swipe-verification input.
        test_swipe_data = {
            'verify_type': 'swipe',
            'os': 'Android',
            'device_type': 'Mobile',
            'ip_type': '电信',
            'load_to_first_delay': 20000,
            'total_duration': 30000,
            'avg_swipe_interval': 100,
            'ip_verify_freq': 10
        }

        # Sample click-verification input.
        test_click_data = {
            'verify_type': 'click',
            'os': 'iOS',
            'device_type': 'Mobile',
            'ip_type': '联通',
            'load_to_first_delay': 15000,
            'total_duration': 25000,
            'click_interval_cv': 0.5,
            'ip_verify_freq': 8
        }

        # Predict only for the models that were actually trained.
        if 'swipe' in models:
            swipe_result = predict_captcha(test_swipe_data)
            print(f"\n📱 滑动验证预测结果:")
            print(f"   验证类型:{swipe_result['verify_type']}")
            print(f"   异常概率:{swipe_result['abnormal_probability']:.4f}")
            print(f"   预测结果:{swipe_result['label_description']}")
            print(f"   预测时间:{swipe_result['prediction_timestamp']}")

        if 'click' in models:
            click_result = predict_captcha(test_click_data)
            print(f"\n🖱️  点击验证预测结果:")
            print(f"   验证类型:{click_result['verify_type']}")
            print(f"   异常概率:{click_result['abnormal_probability']:.4f}")
            print(f"   预测结果:{click_result['label_description']}")
            print(f"   预测时间:{click_result['prediction_timestamp']}")

        print("\n🎉 差异化模型训练和预测完成!")

        # Keep the plot windows open after training finishes.
        plt.ioff()
        plt.show()

    except Exception as e:
        print(f"\n❌ 训练失败:{str(e)}")
        import traceback
        traceback.print_exc()
        plt.ioff()
        raise


# Script entry point: run the full training-and-prediction pipeline.
if __name__ == "__main__":
    main()

推理

/** Prediction model **/

// CaptchaPredict is the process-wide captcha predictor instance.
var CaptchaPredict *CaptchaPredictor

// ========== 1. Structs matching the Python-generated JSON config (incl. median/mode fill fields) ==========

// ModelJSONConfig mirrors the per-verify-type config file written by the
// Python training script (captcha_model_{type}.json). It carries everything
// the Go-side predictor needs to reproduce the Python preprocessing: feature
// order, label-encoder mappings, quantile clip bounds, and missing-value fills.
type ModelJSONConfig struct {
    VerifyType          string                 `json:"verify_type"`          // "swipe" or "click"
    ModelConfig         map[string]interface{} `json:"model_config"`         // raw training hyper-parameters (opaque here)
    PreprocessingParams map[string]interface{} `json:"preprocessing_params"` // per-feature params, e.g. q_low / q_high quantile bounds
    EvaluationMetrics   struct {
        Accuracy                float64 `json:"accuracy"`
        Auc                     float64 `json:"auc"`
        CoreFeatureContribution float64 `json:"core_feature_contribution"`
        BestIteration           int     `json:"best_iteration"`
    } `json:"evaluation_metrics"`
    FeatureImportance map[string]float64 `json:"feature_importance"`
    CoreFeatures      []string           `json:"core_features"`
    FeatureColumns    []string           `json:"feature_columns"` // canonical feature order fed to the model
    TrainingTimestamp string             `json:"training_timestamp"`
    LabelEncoders     map[string]struct {
        Classes    []string       `json:"classes"`
        Mapping    map[string]int `json:"mapping"` // category string -> integer code
        NumClasses int            `json:"num_classes"`
    } `json:"label_encoders"`
    NumericMedians   map[string]float64 `json:"numeric_medians"`   // per-feature medians used to fill missing numeric values
    CategoricalModes map[string]string  `json:"categorical_modes"` // per-feature modes used to fill missing categorical values
}

// ========== 2. XGBoost model struct (native XGBoost 2.x JSON format) ==========

// XGBoostModel maps the top-level layout of an XGBoost 2.x model saved as
// native JSON. Note that XGBoost serializes many numeric parameters
// (base_score, num_trees, num_feature, ...) as strings, hence the string
// field types; callers parse them with strconv where needed.
type XGBoostModel struct {
    Version []int `json:"version"` // e.g. [2, x, y]
    Learner struct {
        Attributes struct {
            BestIteration string `json:"best_iteration"`
            BestScore     string `json:"best_score"`
        } `json:"attributes"`
        FeatureNames    []string `json:"feature_names"`
        FeatureTypes    []string `json:"feature_types"`
        GradientBooster struct {
            Model struct {
                GBTreeModelParam struct {
                    NumParallelTree string `json:"num_parallel_tree"`
                    NumTrees        string `json:"num_trees"`
                } `json:"gbtree_model_param"`
                IterationIndptr []int  `json:"iteration_indptr"`
                TreeInfo        []int  `json:"tree_info"`
                Trees           []Tree `json:"trees"` // array of decision trees
            } `json:"model"`
            TreeParam struct {
                NumFeature string `json:"num_feature"`
            } `json:"tree_param"`
        } `json:"gradient_booster"`
        LearnerModelParam struct {
            BaseScore  string `json:"base_score"` // for binary:logistic this is a probability, not log-odds
            NumClass   string `json:"num_class"`
            NumFeature string `json:"num_feature"`
            NumTarget  string `json:"num_target"`
        } `json:"learner_model_param"`
        Objective struct {
            Name string `json:"name"`
        } `json:"objective"`
    } `json:"learner"`
}

// Tree maps one decision tree of the XGBoost 2.x native JSON format.
// All per-node slices are indexed by node id; node 0 is the root, and a
// LeftChildren value of -1 marks a leaf node.
type Tree struct {
    BaseWeights        []float64 `json:"base_weights"` // per-node base score (leaf output weight at leaves)
    Categories         []string  `json:"categories"`
    CategoriesNodes    []int     `json:"categories_nodes"`
    CategoriesSegments []int     `json:"categories_segments"`
    CategoriesSizes    []int     `json:"categories_sizes"`
    DefaultLeft        []int     `json:"default_left"` // 1 = missing values route to the left child
    ID                 int       `json:"id"`
    LeftChildren       []int     `json:"left_children"` // -1 marks a leaf
    LossChanges        []float64 `json:"loss_changes"`
    Parents            []int     `json:"parents"`
    RightChildren      []int     `json:"right_children"`
    SplitConditions    []float64 `json:"split_conditions"` // split threshold at internal nodes
    SplitIndices       []int     `json:"split_indices"`    // feature index used at each split
    SplitType          []int     `json:"split_type"`
    SumHessian         []float64 `json:"sum_hessian"`
    TreeParam          struct {
        NumDeleted     string `json:"num_deleted"`
        NumFeature     string `json:"num_feature"`
        NumNodes       string `json:"num_nodes"`
        SizeLeafVector string `json:"size_leaf_vector"`
    } `json:"tree_param"`
}

// ========== 3. Per-verify-type predictor ==========

// CaptchaPredictor holds one loaded XGBoost model plus its matching config
// per verify type ("swipe"/"click") and exposes PredictCaptcha as the
// public entry point.
type CaptchaPredictor struct {
    // Root directory containing the model JSON files.
    modelRootPath string

    // Loaded models and configs, keyed by verify type.
    models  map[string]*XGBoostModel
    configs map[string]*ModelJSONConfig

    // True once NewCaptchaPredictor has loaded at least one model.
    initialized bool
}

// ========== 4. Prediction result ==========

// PredictResult is the JSON-serializable outcome of one prediction,
// shaped to match the Python predict_captcha return value.
type PredictResult struct {
    VerifyType          string  `json:"verify_type"`
    AbnormalProbability float64 `json:"abnormal_probability"` // probability of the abnormal class
    PredictedLabel      int     `json:"predicted_label"`      // 1 = abnormal (prob > 0.5), 0 = normal
    LabelDescription    string  `json:"label_description"`
    PredictionTimestamp string  `json:"prediction_timestamp"` // formatted "2006-01-02 15:04:05"
}

// ========== 5. Constructor: build the differentiated predictor ==========

// NewCaptchaPredictor loads the config/model JSON pair for each verify type
// ("swipe" and "click") from modelRootPath. Types whose files fail to load
// are skipped with a warning; an error is returned only if nothing loads.
func NewCaptchaPredictor(modelRootPath string) (*CaptchaPredictor, error) {

    p := &CaptchaPredictor{
        modelRootPath: modelRootPath,
        models:        map[string]*XGBoostModel{},
        configs:       map[string]*ModelJSONConfig{},
    }

    for _, vt := range []string{"swipe", "click"} {
        // Python-generated config: captcha_model_{vt}.json
        cfg, err := loadModelConfig(fmt.Sprintf("%s/captcha_model_%s.json", modelRootPath, vt))
        if err != nil {
            log.Printf("⚠️ 加载%s配置失败: %v,跳过该模型", vt, err)
            continue
        }
        p.configs[vt] = cfg

        // Python-converted native model: captcha_xgb_model_{vt}.json
        mdl, err := loadXGBoostModel(fmt.Sprintf("%s/captcha_xgb_model_%s.json", modelRootPath, vt))
        if err != nil {
            log.Printf("⚠️ 加载%s模型失败: %v,跳过该模型", vt, err)
            // Keep models/configs in lockstep: drop the orphaned config.
            delete(p.configs, vt)
            continue
        }
        p.models[vt] = mdl

        log.Printf("✅ 成功加载%s模型:特征数=%d,决策树数=%d",
            vt, len(cfg.FeatureColumns), len(mdl.Learner.GradientBooster.Model.Trees))
    }

    if len(p.models) == 0 {
        return nil, fmt.Errorf("未加载到任何验证码模型")
    }

    p.initialized = true
    log.Println("✅ 验证码预测器初始化完成")
    return p, nil
}

// ========== 6. Internal loaders ==========

// loadModelConfig reads and decodes one Python-generated model config file.
func loadModelConfig(configPath string) (*ModelJSONConfig, error) {
    raw, err := os.ReadFile(configPath)
    if err != nil {
        return nil, err
    }

    cfg := new(ModelJSONConfig)
    if err := json.Unmarshal(raw, cfg); err != nil {
        return nil, fmt.Errorf("解析配置失败: %v", err)
    }
    return cfg, nil
}

// loadXGBoostModel reads and decodes one XGBoost 2.x native-JSON model file.
func loadXGBoostModel(modelPath string) (*XGBoostModel, error) {
    raw, err := os.ReadFile(modelPath)
    if err != nil {
        return nil, err
    }

    mdl := new(XGBoostModel)
    if err := json.Unmarshal(raw, mdl); err != nil {
        return nil, fmt.Errorf("解析模型失败: %v", err)
    }
    return mdl, nil
}

// ========== 7. Preprocessing (mirrors the Python training pipeline) ==========

// preprocess converts the raw request map into the ordered float feature
// vector expected by the verifyType model, reproducing the Python steps:
// useless-column masking, missing-value filling (mode for categorical /
// median for numeric), label encoding, negative-value repair, quantile
// clipping, then derived interaction features.
//
// Fix: the negative-value repair log previously printed the value AFTER it
// had been zeroed, so "原值" always showed 0.00; it now logs the true
// original value. The dead per-suffix median fallback (every branch yielded
// 0.0) is collapsed to the map's zero value.
func (p *CaptchaPredictor) preprocess(rawData map[string]interface{}, verifyType string) ([]float64, error) {
    if !p.initialized {
        return nil, fmt.Errorf("预测器未初始化")
    }

    // Config for this verify type (feature order, encoders, fill values).
    config, ok := p.configs[verifyType]
    if !ok {
        return nil, fmt.Errorf("未找到%s类型的配置", verifyType)
    }

    featureCols := config.FeatureColumns
    features := make([]float64, len(featureCols))

    // Map feature name -> position in the model's feature vector.
    featureIndex := make(map[string]int, len(featureCols))
    for idx, col := range featureCols {
        featureIndex[col] = idx
    }

    // Identifier-like columns dropped during training (Python USELESS_COLS).
    uselessCols := map[string]bool{
        "id":                 true,
        "sample_id":          true,
        "session_id":         true,
        "ip":                 true,
        "device_fingerprint": true,
        "temp_score":         true,
        "created_at":         true,
        "updated_at":         true,
        "user_id":            true,
        "device_id":          true,
        "phone_number":       true,
        "email":              true,
        "cookie_id":          true,
        "session_token":      true,
    }

    for col, idx := range featureIndex {
        // Useless columns are forced to 0 regardless of input.
        if uselessCols[col] {
            features[idx] = 0.0
            continue
        }

        val, exists := rawData[col]

        // Missing value: mode fill for categorical, median fill for numeric.
        if !exists || val == nil {
            if le, ok := config.LabelEncoders[col]; ok {
                modeVal := config.CategoricalModes[col]
                if modeVal == "" {
                    modeVal = "Unknown"
                }
                code, ok := le.Mapping[modeVal]
                if !ok {
                    code = le.Mapping["Unknown"]
                }
                features[idx] = float64(code)
                log.Printf("⚠️ 特征[%s]缺失,填充众数[%s]编码: %d", col, modeVal, code)
                continue
            }

            // Numeric: median when known, otherwise the map's zero value 0.0
            // (the previous delay/duration/freq branches all produced 0.0).
            features[idx] = config.NumericMedians[col]
            continue
        }

        // Categorical feature: map the value through the label encoder.
        if le, ok := config.LabelEncoders[col]; ok {
            var strVal string
            switch v := val.(type) {
            case string:
                strVal = v
            case nil:
                strVal = "Unknown"
            default:
                strVal = fmt.Sprintf("%v", v)
            }

            code, ok := le.Mapping[strVal]
            if !ok {
                // Unseen category falls back to the "Unknown" code.
                code = le.Mapping["Unknown"]
                log.Printf("⚠️ 特征[%s]未知值: %s,使用Unknown编码: %d", col, strVal, code)
            }
            features[idx] = float64(code)
            continue
        }

        // Numeric feature: coerce to float64; unparsable or unexpected
        // types fall back to the training-time median.
        var numVal float64
        switch v := val.(type) {
        case float64:
            numVal = v
        case int:
            numVal = float64(v)
        case int64:
            numVal = float64(v)
        case string:
            n, err := strconv.ParseFloat(v, 64)
            if err != nil {
                medianVal, ok := config.NumericMedians[col]
                if !ok {
                    medianVal = 0.0
                }
                numVal = medianVal
                log.Printf("⚠️ 特征[%s]值错误: %s,填充中位数: %.2f", col, v, medianVal)
            } else {
                numVal = n
            }
        default:
            medianVal, ok := config.NumericMedians[col]
            if !ok {
                medianVal = 0.0
            }
            numVal = medianVal
            log.Printf("⚠️ 特征[%s]类型错误: %T,填充中位数: %.2f", col, v, medianVal)
        }

        // Negative repair (Python replaces all negatives with 0). Capture
        // the original value first so the log is meaningful.
        if numVal < 0 {
            origVal := numVal
            numVal = 0.0
            log.Printf("⚠️ 特征[%s]负数修复为0(原值: %.2f)", col, origVal)
        }

        // Extreme-value clipping using the quantiles saved by Python.
        if colParams, ok := config.PreprocessingParams[col].(map[string]interface{}); ok {
            if qLow, ok := colParams["q_low"].(float64); ok && numVal < qLow {
                log.Printf("⚠️ 特征[%s]低于下界%.2f,截断为%.2f(原值: %.2f)", col, qLow, qLow, numVal)
                numVal = qLow
            }
            if qHigh, ok := colParams["q_high"].(float64); ok && numVal > qHigh {
                log.Printf("⚠️ 特征[%s]高于上界%.2f,截断为%.2f(原值: %.2f)", col, qHigh, qHigh, numVal)
                numVal = qHigh
            }
        }

        features[idx] = numVal
    }

    // Derived interaction features (matches Python create_core_feature_interactions).
    p.calculateFeatureInteractions(features, featureIndex, rawData, verifyType)

    return features, nil
}

// calculateFeatureInteractions fills in the derived interaction features in
// place (matches Python's create_core_feature_interactions). Each interaction
// is computed only when all of its source columns and its target column exist
// in the feature index.
func (p *CaptchaPredictor) calculateFeatureInteractions(features []float64, featureIndex map[string]int, rawData map[string]interface{}, verifyType string) {
    loadIdx, hasLoad := featureIndex["load_to_first_delay"]
    totalIdx, hasTotal := featureIndex["total_duration"]

    // Shared interaction: load_to_total_ratio = load / (total + 1e-6).
    if ratioIdx, ok := featureIndex["load_to_total_ratio"]; ok && hasLoad && hasTotal {
        ratio := features[loadIdx] / (features[totalIdx] + 1e-6)
        features[ratioIdx] = ratio
        log.Printf("✅ 计算通用交互项 load_to_total_ratio = %.4f", ratio)
    }

    // Type-specific interactions.
    switch verifyType {
    case "swipe":
        // load_swipe_ratio = load_to_first_delay / (avg_swipe_interval + 1e-6)
        swipeIdx, hasSwipe := featureIndex["avg_swipe_interval"]
        if ratioIdx, ok := featureIndex["load_swipe_ratio"]; ok && hasLoad && hasSwipe {
            ratio := features[loadIdx] / (features[swipeIdx] + 1e-6)
            features[ratioIdx] = ratio
            log.Printf("✅ 计算滑动交互项 load_swipe_ratio = %.4f", ratio)
        }

    case "click":
        // duration_click_cv = total_duration * click_interval_cv
        clickIdx, hasClick := featureIndex["click_interval_cv"]
        if cvIdx, ok := featureIndex["duration_click_cv"]; ok && hasTotal && hasClick {
            product := features[totalIdx] * features[clickIdx]
            features[cvIdx] = product
            log.Printf("✅ 计算点击交互项 duration_click_cv = %.4f", product)
        }
    }
}

// ========== 8. Prediction (native XGBoost 2.x tree traversal) ==========

// predict runs the preprocessed feature vector through every decision tree of
// the verifyType model, sums the leaf scores, and applies the binary:logistic
// link to obtain (probability, label). Returns an error if the model/config
// is missing or the feature length does not match.
//
// Fixes vs. previous revision (aligning with XGBoost semantics):
//   - Missing values (NaN) now follow the tree's default_left direction
//     instead of always going left; ±Inf is no longer treated as missing,
//     since Inf compares normally against the threshold.
//   - The branch test is fvalue < split_condition → left (was <=).
func (p *CaptchaPredictor) predict(features []float64, verifyType string) (float64, int, error) {
    if !p.initialized {
        return 0, 0, fmt.Errorf("预测器未初始化")
    }

    model, ok := p.models[verifyType]
    if !ok {
        return 0, 0, fmt.Errorf("未找到%s类型的模型", verifyType)
    }

    config, ok := p.configs[verifyType]
    if !ok {
        return 0, 0, fmt.Errorf("未找到%s类型的配置", verifyType)
    }

    // base_score is serialized as a string; for binary:logistic it is a
    // probability, converted to log-odds below. Default 0.5 → log-odds 0.
    actualBaseScore := 0.5
    if model.Learner.LearnerModelParam.BaseScore != "" {
        if bs, err := strconv.ParseFloat(model.Learner.LearnerModelParam.BaseScore, 64); err == nil {
            actualBaseScore = bs
        }
    }

    featureCols := config.FeatureColumns
    if len(features) != len(featureCols) {
        return 0, 0, fmt.Errorf("特征长度不匹配:期望%d,实际%d", len(featureCols), len(features))
    }

    trees := model.Learner.GradientBooster.Model.Trees
    if len(trees) == 0 {
        log.Println("⚠️ 模型中未找到决策树,返回默认概率0.5")
        return 0.5, 0, nil
    }

    totalScore := 0.0
    validTreeCount := 0 // number of trees that reached a leaf

    for treeIdx, tree := range trees {
        nodeCount, err := strconv.Atoi(tree.TreeParam.NumNodes)
        if err != nil || nodeCount == 0 {
            log.Printf("⚠️ 树%d节点数无效: %s,跳过", treeIdx, tree.TreeParam.NumNodes)
            continue
        }

        currentNodeID := 0 // start at the root
        for {
            // Bounds guard against malformed trees.
            if currentNodeID < 0 || currentNodeID >= len(tree.LeftChildren) {
                break
            }

            leftChild := tree.LeftChildren[currentNodeID]
            if leftChild == -1 {
                // Leaf: accumulate its weight and stop this tree.
                if currentNodeID < len(tree.BaseWeights) {
                    totalScore += tree.BaseWeights[currentNodeID]
                    validTreeCount++
                }
                break
            }

            splitFeatIdx := tree.SplitIndices[currentNodeID]
            splitThreshold := tree.SplitConditions[currentNodeID]

            // Out-of-range feature index: keep the previous right-branch
            // fallback so malformed models still terminate.
            if splitFeatIdx < 0 || splitFeatIdx >= len(features) {
                currentNodeID = tree.RightChildren[currentNodeID]
                continue
            }

            featVal := features[splitFeatIdx]
            switch {
            case math.IsNaN(featVal):
                // Missing value: route by the learned default direction.
                if currentNodeID < len(tree.DefaultLeft) && tree.DefaultLeft[currentNodeID] != 0 {
                    currentNodeID = leftChild
                } else {
                    currentNodeID = tree.RightChildren[currentNodeID]
                }
            case featVal < splitThreshold:
                // XGBoost rule: fvalue < split_condition goes left.
                currentNodeID = leftChild
            default:
                currentNodeID = tree.RightChildren[currentNodeID]
            }
        }
    }

    // Binary classification: prob = sigmoid(logit(base_score) + Σ leaf scores).
    baseLogOdds := math.Log(actualBaseScore / (1 - actualBaseScore))
    finalLogOdds := baseLogOdds + totalScore
    predProb := 1.0 / (1.0 + math.Exp(-finalLogOdds))

    // Label at the 0.5 threshold, matching the Python side.
    predLabel := 0
    if predProb > 0.5 {
        predLabel = 1
    }

    log.Printf("✅ %s模型推理完成:有效树数=%d,BaseScore=%.4f,总得分=%.4f,概率=%.4f,标签=%d",
        verifyType, validTreeCount, actualBaseScore, totalScore, predProb, predLabel)

    return predProb, predLabel, nil
}

// ========== 9. Unified prediction entry point (mirrors Python's predict_captcha) ==========

// PredictCaptcha resolves the verify type from rawData, preprocesses the
// fields, runs the matching model, and returns the structured result.
// Unsupported types fall back to swipe, then click, matching the Python side.
func (p *CaptchaPredictor) PredictCaptcha(rawData map[string]interface{}) (*PredictResult, error) {
    if !p.initialized {
        return nil, fmt.Errorf("预测器未初始化")
    }

    // Resolve the requested verify type (default when absent or non-string).
    vt := "default"
    if s, ok := rawData["verify_type"].(string); ok {
        vt = s
    }
    vt = strings.ToLower(vt)

    // Unsupported type: prefer the swipe model, then click.
    if vt != "swipe" && vt != "click" {
        fallback := ""
        if _, ok := p.models["swipe"]; ok {
            fallback = "swipe"
        } else if _, ok := p.models["click"]; ok {
            fallback = "click"
        }
        if fallback == "" {
            return nil, fmt.Errorf("无可用的验证码模型")
        }
        vt = fallback
        log.Printf("⚠️ 验证类型%s不支持,自动切换为%s", rawData["verify_type"], vt)
    }

    // Preprocess into the model's feature vector.
    features, err := p.preprocess(rawData, vt)
    if err != nil {
        return nil, fmt.Errorf("预处理失败: %v", err)
    }

    // Run inference.
    prob, label, err := p.predict(features, vt)
    if err != nil {
        return nil, fmt.Errorf("预测失败: %v", err)
    }

    desc := "正常"
    if label == 1 {
        desc = "异常/疑似"
    }

    // Shape matches the Python predict_captcha return value.
    return &PredictResult{
        VerifyType:          vt,
        AbnormalProbability: prob,
        PredictedLabel:      label,
        LabelDescription:    desc,
        PredictionTimestamp: time.Now().Format("2006-01-02 15:04:05"),
    }, nil
}
最后修改:2026 年 01 月 12 日 10:32 PM
如果觉得我的文章对你有用,请随意赞赏