控制器
var captchaRequest request.VerifyCodeRequest
if err := ctx.ShouldBindWith(&captchaRequest, binding.JSON); err != nil {
helper.Info(err.Error())
ReturnError(ctx, defined.ERROR_500)
return
}
var jsonStr, err = service.CaptchaAESDecrypt(captchaRequest.Ciphertext)
if err != nil {
helper.Error(err)
ReturnSuccess(ctx, nil)
return
}
// 1. 初始化 RawData 变量
var rawData service.CaptchaRawData
// 2. 解析 JSON 字符串到结构体
// Unmarshal 要求 JSON 字段类型与结构体字段类型匹配(如数字 vs int64,字符串 vs string)
err = json.Unmarshal([]byte(jsonStr), &rawData)
if err != nil {
ReturnError(ctx, defined.ERROR_500)
return
}
service.CaptchaConsole(rawData, rawData.Env.SessionId, ctx)
ReturnSuccess(ctx, rawData)
服务
package service
import (
"activity/cache"
"activity/dao"
"activity/etc"
"activity/helper"
"activity/model"
"activity/third/idc"
"activity/third/ip2region"
"activity/utils"
"crypto/aes"
"crypto/cipher"
"encoding/hex"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"gonum.org/v1/gonum/stat"
"math"
)
// Threshold constants, centralized here for easier maintenance.
const (
	// Generic rule thresholds (base rules shared by every verify type).
	DeviceIPBindCountMaxNormal   = 2  // max IPs bound to one device still considered normal
	DeviceIPBindCountMinRisk     = 3  // min bound-IP count for "risk"
	DeviceIPBindCountMaxRisk     = 4  // max bound-IP count for "risk"
	DeviceIPBindCountMinAbnormal = 5  // min bound-IP count for "abnormal"
	IPVerifyFreqMinRisk          = 10 // min verification frequency for "risk"
	IPVerifyFreqMaxRisk          = 19 // max verification frequency for "risk"
	IPVerifyFreqMinAbnormal      = 20 // min verification frequency for "abnormal"
	// ------------ load_to_first_delay thresholds per verify type (ms) ------------
	// Swipe: simplest gesture, shortest expected delays.
	SwipeDelayMinRisk     = 80   // swipe - risk lower bound (slightly short)
	SwipeDelayMinNormal   = 200  // swipe - normal lower bound
	SwipeDelayMaxNormal   = 2000 // swipe - normal upper bound
	SwipeDelayMaxRisk     = 3000 // swipe - risk upper bound (slightly long)
	SwipeDelayMinAbnormal = 50   // swipe - abnormal lower bound (too short)
	SwipeDelayMaxAbnormal = 4000 // swipe - abnormal upper bound (too long)
	// Click (text click): target must be recognized first, medium delays.
	ClickDelayMinRisk     = 100  // click - risk lower bound
	ClickDelayMinNormal   = 300  // click - normal lower bound
	ClickDelayMaxNormal   = 7000 // click - normal upper bound
	ClickDelayMaxRisk     = 8000 // click - risk upper bound
	ClickDelayMinAbnormal = 70   // click - abnormal lower bound
	ClickDelayMaxAbnormal = 9000 // click - abnormal upper bound
	// Input (typed text): reading + typing, longest expected delays.
	InputDelayMinRisk     = 200   // input - risk lower bound
	InputDelayMinNormal   = 500   // input - normal lower bound
	InputDelayMaxNormal   = 5000  // input - normal upper bound
	InputDelayMaxRisk     = 7000  // input - risk upper bound
	InputDelayMinAbnormal = 100   // input - abnormal lower bound
	InputDelayMaxAbnormal = 10000 // input - abnormal upper bound (longer hesitation allowed)
	// Other types (drag, ...): medium defaults.
	DefaultDelayMinRisk     = 150  // default - risk lower bound
	DefaultDelayMinNormal   = 350  // default - normal lower bound
	DefaultDelayMaxNormal   = 3000 // default - normal upper bound
	DefaultDelayMaxRisk     = 4000 // default - risk upper bound
	DefaultDelayMinAbnormal = 80   // default - abnormal lower bound
	DefaultDelayMaxAbnormal = 5000 // default - abnormal upper bound
	// Click-type thresholds (coefficient of variation of click intervals).
	ClickCVLowAbnormal  = 0.08 // click CV too low (abnormal)
	ClickCVHighAbnormal = 0.85 // click CV too high (abnormal)
	ClickCVNormalMin    = 0.09 // click CV normal lower bound
	// Swipe-type thresholds (unit: ms).
	SwipeIntervalMinAbnormal = 2  // swipe interval too low (abnormal)
	SwipeIntervalMaxAbnormal = 12 // swipe interval too high (abnormal)
	SwipeIntervalMinRiskFast = 2  // slightly-fast swipe lower bound (risk)
	SwipeIntervalMaxRiskFast = 3  // slightly-fast swipe upper bound (risk)
	SwipeIntervalMinRiskSlow = 10 // slightly-slow swipe lower bound (risk)
	SwipeIntervalMaxRiskSlow = 12 // slightly-slow swipe upper bound (risk)
	SwipeIntervalMinNormal   = 3  // normal swipe lower bound
	SwipeIntervalMaxNormal   = 10 // normal swipe upper bound
	// Input-type thresholds (unit: ms).
	InputIntervalMinAbnormal = 150  // input interval too low (abnormal)
	InputIntervalMaxAbnormal = 2500 // input interval too high (abnormal)
	InputIntervalMinRiskFast = 150  // slightly-fast input lower bound (risk)
	InputIntervalMaxRiskFast = 400  // slightly-fast input upper bound (risk)
	InputIntervalMinRiskSlow = 2000 // slightly-slow input lower bound (risk)
	InputIntervalMaxRiskSlow = 2500 // slightly-slow input upper bound (risk)
	InputIntervalMinNormal   = 400  // normal input lower bound
	InputIntervalMaxNormal   = 2000 // normal input upper bound
)
// CaptchaConsole runs the anti-cheat pipeline for one captcha verification:
// it validates the raw timing data, extracts and auto-labels features, scores
// the sample with the prediction model, and persists the feature row.
func CaptchaConsole(raw CaptchaRawData, sessionId string, req *gin.Context) {
	// Drop samples whose timestamp fields don't match their verify type.
	if !validateTimestamps(raw) {
		helper.Warning("验证码校验数据异常")
		return
	}
	feature, err := ExtractFeatures(&raw, sessionId, req)
	if err != nil {
		helper.Error("验证码收集错误", err.Error())
		return
	}
	// Feature vector handed to the prediction model.
	swipeData := map[string]interface{}{
		"verify_type":          feature.VerifyType,
		"os":                   feature.OS,
		"device_type":          feature.DeviceType,
		"ip_type":              feature.IPType,
		"load_to_first_delay":  feature.LoadToFirstDelay,
		"total_duration":       feature.TotalDuration,
		"avg_swipe_interval":   feature.AvgSwipeInterval,
		"ip_verify_freq":       feature.IPVerifyFreq,
		"device_ip_bind_count": feature.DeviceIPBindCount,
	}
	swipeResult, err := utils.CaptchaPredict.PredictCaptcha(swipeData)
	if err != nil {
		// BUG FIX: log.Fatalf called os.Exit(1) and killed the whole service
		// on a single failed prediction. Log the error and skip this sample.
		log.Errorf("❌ 验证预测失败: %v", err)
		return
	}
	helper.Info("预测模型", swipeResult.AbnormalProbability, swipeResult.LabelDescription)
	// Score usage: logging / manual-review reference / caching, etc.
	helper.Printf("样本[%s] 反作弊得分:%d,标签:%s", feature.SampleID, feature.TempScore, feature.Label)
	// BUG FIX: feature is already a *model.CaptchaVerifyFeature; Save(&feature)
	// handed GORM a **T. Also surface the previously-dropped save error.
	if dbErr := dao.GetPlatformDB().Model(&model.CaptchaVerifyFeature{}).Save(feature).Error; dbErr != nil {
		helper.Error("保存验证码特征失败", dbErr.Error())
	}
}
// validateTimestamps reports whether the timestamp fields required by the
// sample's verify type are present. "none" (frictionless) has no requirements;
// unknown types are rejected.
func validateTimestamps(raw CaptchaRawData) bool {
	verifyType := raw.Env.VerifyType
	// Frictionless verification carries no timestamp requirements at all.
	if verifyType == "none" {
		return true
	}
	switch verifyType {
	case "swipe", "click":
		// Gesture-style verification must record operate timestamps.
		if len(raw.Timing.OperateTimestamps) > 0 {
			return true
		}
		log.Printf("样本[%s]:滑动/点选验证缺少operateTimestamps", raw.Env.SessionId)
	case "input", "sms", "text":
		// Typed verification must record input timestamps.
		if len(raw.Timing.InputTimestamps) > 0 {
			return true
		}
		log.Printf("样本[%s]:输入类验证缺少inputTimestamps", raw.Env.SessionId)
	default:
		log.Printf("样本[%s]:未知验证类型[%s]", raw.Env.SessionId, raw.Env.VerifyType)
	}
	return false
}
// ========== Struct definitions (mirror the raw client payload) ==========
// Env is the environment section of the raw client data.
type Env struct {
	VerifyType        string `json:"verifyType"`        // captcha type (swipe/click/input/...)
	DeviceFingerprint string `json:"deviceFingerprint"` // device fingerprint
	OS                string `json:"os"`                // operating system
	DeviceType        string `json:"deviceType"`        // device type
	// Fields added in a later payload revision:
	Browser       string `json:"browser"`       // browser name (with version)
	CookieEnabled bool   `json:"cookieEnabled"` // whether cookies are enabled
	PluginCount   int    `json:"pluginCount"`   // number of browser plugins
	ScreenRes     string `json:"screenRes"`     // screen resolution
	SessionId     string `json:"sessionId"`     // session id
	UserAgent     string `json:"userAgent"`     // user agent (encrypted/hashed upstream per original note — confirm)
}

// Timing is the timing-statistics section of the raw client data.
// Timestamps are int64; presumably unix milliseconds — confirm with the frontend.
type Timing struct {
	FirstOperateTime  int64   `json:"firstOperateTime"`  // timestamp of the first user operation
	LoadTime          int64   `json:"loadTime"`          // timestamp when the captcha loaded
	SubmitTime        int64   `json:"submitTime"`        // timestamp when the answer was submitted
	OperateTimestamps []int64 `json:"operateTimestamps"` // per-step timestamps (swipe/click/drag)
	InputTimestamps   []int64 `json:"inputTimestamps"`   // per-keystroke timestamps (input-type captcha)
	CorrectCount      int     `json:"correctCount"`      // number of correct inputs
}

// CaptchaRawData is the root of the decrypted client payload.
type CaptchaRawData struct {
	Env    Env    `json:"env"`
	Timing Timing `json:"timing"`
}

// Request is a simplified request struct (extend as needed).
type Request struct {
	IP string `json:"ip"` // IP address (used for IP-type / frequency lookups)
}
// ========== External dependencies ==========
// getIPType classifies realIP using third-party lookups: IDC tag first, then
// an untrusted-proxy check (skipped in the test environment), then the ISP
// reported by ip2region. Returns one of the category strings used by autoLabel.
func getIPType(req *gin.Context, realIP string) string {
	// A non-empty IDC tag means the IP belongs to a data center.
	tag, err := idc.Search(realIP)
	if err != nil {
		return "异常"
	}
	if tag != "" {
		return "IDC"
	}
	// Outside the test environment, flag requests arriving via untrusted proxies.
	if etc.CommonConfigVar.Common.Env != "test" && helper.IsViaUntrustedProxy(req) {
		return "代理"
	}
	// Fall back to ISP classification; the lookup error is deliberately
	// ignored because an empty ISP is treated as unknown anyway.
	_, regionInfo, _ := ip2region.Search(realIP)
	if regionInfo.ISP == "" {
		return "未知"
	}
	return regionInfo.ISP
}
// getDeviceIPBindCount records realIP under the device fingerprint and returns
// the number of distinct IPs that fingerprint has used within the last 7 days
// (Redis set whose TTL is refreshed on every add). Errors are best-effort ignored.
func getDeviceIPBindCount(deviceFP string, realIP string) int {
	const bindTTL = 3600 * 24 * 7 // seven days, in seconds
	key := fmt.Sprintf(cache.CAPTCHA_VERIFY_DEVICE_IP, deviceFP)
	cache.Cache.SAddWithExpire(key, realIP, bindTTL)
	size, _ := cache.Cache.SCard(key)
	return size
}
// getIPVerifyFreq increments and returns this IP's recent verification counter
// (Redis counter with a TTL).
// NOTE(review): the original comment said "last 4 hours", but the TTL passed
// below is 3600 seconds (1 hour) — confirm which window is intended.
func getIPVerifyFreq(realIP string) int {
	var key = fmt.Sprintf(cache.CAPTCHA_VERIFY_IP_FREQ, realIP)
	// Error deliberately ignored (best-effort counter).
	count, _ := cache.Cache.IncrWithExpire(key, 3600)
	return int(count)
}
// DelayThreshold groups the load_to_first_delay thresholds for one verify type.
// Band layout: below MinAbnormal or above MaxAbnormal => abnormal;
// [MinNormal, MaxNormal] => normal; the gaps on either side => risk.
type DelayThreshold struct {
	MinAbnormal int // below this: abnormal (too short)
	MaxAbnormal int // above this: abnormal (too long)
	MinRisk     int // risk lower bound (above abnormal floor, below normal floor)
	MaxRisk     int // risk upper bound (above normal ceiling, below abnormal ceiling)
	MinNormal   int // normal lower bound
	MaxNormal   int // normal upper bound
}
// getDelayThresholdByType maps a verify type to its load_to_first_delay
// thresholds. Unknown types (e.g. "drag") fall back to the defaults.
func getDelayThresholdByType(verifyType string) DelayThreshold {
	byType := map[string]DelayThreshold{
		"swipe": {
			MinAbnormal: SwipeDelayMinAbnormal,
			MaxAbnormal: SwipeDelayMaxAbnormal,
			MinRisk:     SwipeDelayMinRisk,
			MaxRisk:     SwipeDelayMaxRisk,
			MinNormal:   SwipeDelayMinNormal,
			MaxNormal:   SwipeDelayMaxNormal,
		},
		"click": {
			MinAbnormal: ClickDelayMinAbnormal,
			MaxAbnormal: ClickDelayMaxAbnormal,
			MinRisk:     ClickDelayMinRisk,
			MaxRisk:     ClickDelayMaxRisk,
			MinNormal:   ClickDelayMinNormal,
			MaxNormal:   ClickDelayMaxNormal,
		},
		"input": {
			MinAbnormal: InputDelayMinAbnormal,
			MaxAbnormal: InputDelayMaxAbnormal,
			MinRisk:     InputDelayMinRisk,
			MaxRisk:     InputDelayMaxRisk,
			MinNormal:   InputDelayMinNormal,
			MaxNormal:   InputDelayMaxNormal,
		},
	}
	if threshold, ok := byType[verifyType]; ok {
		return threshold
	}
	// drag and any other unrecognized type use the default thresholds.
	return DelayThreshold{
		MinAbnormal: DefaultDelayMinAbnormal,
		MaxAbnormal: DefaultDelayMaxAbnormal,
		MinRisk:     DefaultDelayMinRisk,
		MaxRisk:     DefaultDelayMaxRisk,
		MinNormal:   DefaultDelayMinNormal,
		MaxNormal:   DefaultDelayMaxNormal,
	}
}
// autoLabel scores a feature row with rule-based anti-cheat heuristics and
// returns (label, score), where label is "normal" / "risk" / "abnormal".
// Evaluation order matters: generic abnormal rules -> type-specific abnormal
// rules -> positive (normal) rules -> risk bands -> fallback "risk".
func autoLabel(features *model.CaptchaVerifyFeature, timing *Timing) (string, int) {
	// Guard: nil features fall back to risk with a zero score.
	if features == nil {
		return "risk", 0
	}
	// Silence the unused-parameter warning (timing kept for future rules).
	_ = timing
	// Base anti-cheat score; rules below subtract (or add) from 100.
	score := 100
	verifyType := features.VerifyType
	// Per-type load_to_first_delay thresholds.
	delayThreshold := getDelayThresholdByType(verifyType)
	// Current sample's load->first-operation delay (int64 -> int so it can be
	// compared directly against the int thresholds).
	currentDelay := int(features.LoadToFirstDelay)
	// ========== Step 1: generic abnormal rules (return "abnormal" immediately) ==========
	// Rule 1: delay below the type-specific abnormal floor (too fast for a human).
	if currentDelay < delayThreshold.MinAbnormal {
		score -= 30
		return "abnormal", score
	}
	// Rule 2: delay above the type-specific abnormal ceiling (suspiciously slow).
	if currentDelay > delayThreshold.MaxAbnormal {
		score -= 25
		return "abnormal", score
	}
	// Rule 3: device bound to >=5 IPs (multi-IP abuse).
	if features.DeviceIPBindCount >= DeviceIPBindCountMinAbnormal {
		score -= 20
		return "abnormal", score
	}
	// Rule 4: IP verification frequency >=20 (high-frequency abuse).
	if features.IPVerifyFreq >= IPVerifyFreqMinAbnormal {
		score -= 15
		return "abnormal", score
	}
	// Rule 5: high-risk IP category (error / proxy / unknown / IDC).
	if helper.Contains([]string{"异常", "代理", "未知", "IDC"}, features.IPType) {
		score -= 50
		return "abnormal", score
	}
	// ========== Step 2: type-specific abnormal rules ==========
	// Tracks whether a type-specific abnormal rule fired.
	hasExclusiveAbnormal := false
	switch verifyType {
	case "click": // text click
		// Abnormal: click-interval CV outside the sane band (-1 means "not measured").
		if features.ClickIntervalCV != -1 {
			if features.ClickIntervalCV < ClickCVLowAbnormal || features.ClickIntervalCV > ClickCVHighAbnormal {
				score -= 25
				hasExclusiveAbnormal = true
			} else if features.ClickIntervalCV >= ClickCVLowAbnormal && features.ClickIntervalCV < ClickCVNormalMin {
				// Fuzzy zone next to the abnormal band: penalize without returning.
				score -= 10
			}
		}
	case "swipe": // swipe verification
		// Abnormal: average swipe interval <2ms or >12ms (mechanical / oddly slow).
		if features.AvgSwipeInterval != -1 &&
			(features.AvgSwipeInterval < SwipeIntervalMinAbnormal || features.AvgSwipeInterval > SwipeIntervalMaxAbnormal) {
			score -= 25
			hasExclusiveAbnormal = true
		} else if features.AvgSwipeInterval != -1 {
			// Fuzzy zones: 2~3ms (slightly fast) or 10~12ms (slightly slow).
			if (features.AvgSwipeInterval >= SwipeIntervalMinRiskFast && features.AvgSwipeInterval < SwipeIntervalMaxRiskFast) ||
				(features.AvgSwipeInterval >= SwipeIntervalMinRiskSlow && features.AvgSwipeInterval <= SwipeIntervalMaxRiskSlow) {
				score -= 8
			}
		}
	case "input": // typed text
		// Abnormal: average input interval <150ms (autofill) or >2500ms (oddly slow).
		if features.AvgInputInterval != -1 &&
			(features.AvgInputInterval < InputIntervalMinAbnormal || features.AvgInputInterval > InputIntervalMaxAbnormal) {
			score -= 25
			hasExclusiveAbnormal = true
		} else if features.AvgInputInterval != -1 {
			// Fuzzy zones: 150~400ms (slightly fast) or 2000~2500ms (slightly slow).
			if (features.AvgInputInterval >= InputIntervalMinRiskFast && features.AvgInputInterval < InputIntervalMaxRiskFast) ||
				(features.AvgInputInterval >= InputIntervalMinRiskSlow && features.AvgInputInterval <= InputIntervalMaxRiskSlow) {
				score -= 8
			}
		}
	}
	// A type-specific abnormal rule fired: return immediately.
	if hasExclusiveAbnormal {
		return "abnormal", score
	}
	// ========== Step 3: positive rules (real-user traits -> "normal") ==========
	// Base condition: delay inside the normal band + normal device-IP bindings.
	baseNormalCondition := currentDelay >= delayThreshold.MinNormal &&
		currentDelay <= delayThreshold.MaxNormal &&
		features.DeviceIPBindCount <= DeviceIPBindCountMaxNormal
	// Per-type normal checks.
	switch verifyType {
	case "click":
		if baseNormalCondition && features.ClickIntervalCV != -1 && features.ClickIntervalCV > ClickCVNormalMin {
			score += 10
			return "normal", score
		}
	case "swipe":
		if baseNormalCondition && features.AvgSwipeInterval != -1 &&
			features.AvgSwipeInterval >= SwipeIntervalMinNormal && features.AvgSwipeInterval <= SwipeIntervalMaxNormal {
			score += 10
			return "normal", score
		}
	case "input":
		if baseNormalCondition && features.AvgInputInterval != -1 &&
			features.AvgInputInterval >= InputIntervalMinNormal && features.AvgInputInterval <= InputIntervalMaxNormal {
			score += 10
			return "normal", score
		}
	default:
		// Other types (e.g. drag): the base condition alone is enough.
		if baseNormalCondition {
			score += 10
			return "normal", score
		}
	}
	// ========== Step 4: core risk detection (type-specific delay bands) ==========
	// Any one of the following conditions marks the sample as risk.
	var isRisk bool
	// Cond 1: delay between the abnormal floor and the normal floor (slightly short).
	delayRiskLow := currentDelay >= delayThreshold.MinAbnormal && currentDelay < delayThreshold.MinNormal
	// Cond 2: delay between the normal ceiling and the abnormal ceiling (slightly long).
	delayRiskHigh := currentDelay > delayThreshold.MaxNormal && currentDelay <= delayThreshold.MaxAbnormal
	// Cond 3: device-IP bindings slightly above normal (3-4).
	ipBindRisk := features.DeviceIPBindCount >= DeviceIPBindCountMinRisk && features.DeviceIPBindCount <= DeviceIPBindCountMaxRisk
	// Cond 4: IP verify frequency slightly above normal (10-19).
	ipFreqRisk := features.IPVerifyFreq >= IPVerifyFreqMinRisk && features.IPVerifyFreq <= IPVerifyFreqMaxRisk
	// Risk if any single condition holds.
	isRisk = delayRiskLow || delayRiskHigh || ipBindRisk || ipFreqRisk
	// ========== Step 5: fine-grained deductions for risk samples ==========
	if isRisk {
		// Slightly-short delay: per-type deduction (swipe is most delay-sensitive).
		if delayRiskLow {
			switch verifyType {
			case "swipe":
				score -= 8 // slightly-short swipe delay: likely scripted
			case "click":
				score -= 6 // slightly-short click delay: medium risk
			case "input":
				score -= 4 // slightly-short input delay: lower risk
			default:
				score -= 5
			}
		}
		// Slightly-long delay: input users are allowed more hesitation.
		if delayRiskHigh {
			switch verifyType {
			case "swipe":
				score -= 6 // slightly-long swipe delay: lag or automation
			case "click":
				score -= 5 // slightly-long click delay: medium risk
			case "input":
				score -= 3 // slightly-long input delay: normal hesitation
			default:
				score -= 4
			}
		}
		// Slightly-high device-IP bindings: moderate deduction.
		if ipBindRisk {
			score -= 8
		}
		// Slightly-high IP verify frequency: small deduction.
		if ipFreqRisk {
			score -= 5
		}
		// Explicit risk verdict (before the fallback).
		return "risk", score
	}
	// ========== Fallback: edge samples matching no rule still count as risk ==========
	return "risk", score
}
// ========== Core feature extraction ==========
// ExtractFeatures builds a CaptchaVerifyFeature row from the raw client data,
// enriches it with IP/device statistics, and auto-labels it via autoLabel.
func ExtractFeatures(rawData *CaptchaRawData, sessionID string, req *gin.Context) (*model.CaptchaVerifyFeature, error) {
	// 1. Pull out the raw basics.
	env := rawData.Env
	timing := rawData.Timing
	verifyType := env.VerifyType
	realIP := helper.GetIp(req)
	helper.Info("ip:", realIP)
	// 2. Base features.
	features := &model.CaptchaVerifyFeature{
		SampleID:          uuid.NewString(), // UUID, same role as Python's uuid.uuid4()
		SessionID:         sessionID,
		VerifyType:        verifyType,
		LoadToFirstDelay:  timing.FirstOperateTime - timing.LoadTime,
		TotalDuration:     timing.SubmitTime - timing.LoadTime,
		DeviceFingerprint: env.DeviceFingerprint,
		OS:                env.OS,
		DeviceType:        env.DeviceType,
		IPType:            getIPType(req, realIP),
		DeviceIPBindCount: getDeviceIPBindCount(env.DeviceFingerprint, realIP),
		IPVerifyFreq:      getIPVerifyFreq(realIP),
		Ip:                realIP,
	}
	// BUG FIX: autoLabel treats -1 as the "not measured" sentinel for these
	// interval features (it guards every rule with `!= -1`). The Go zero value
	// (0) left by the literal above was misread as a measured value of zero
	// and tripped the "too fast" abnormal rules whenever fewer than two
	// timestamps were collected.
	features.ClickIntervalCV = -1
	features.AvgSwipeInterval = -1
	features.AvgInputInterval = -1
	// 3. Type-specific features.
	switch verifyType {
	case "swipe", "click", "drag":
		operateTS := timing.OperateTimestamps
		features.OperateStepCount = len(operateTS)
		if len(operateTS) > 1 {
			// Inter-operation intervals (np.diff equivalent).
			intervals := make([]float64, len(operateTS)-1)
			for i := 0; i < len(operateTS)-1; i++ {
				intervals[i] = float64(operateTS[i+1] - operateTS[i])
			}
			switch verifyType {
			case "click":
				// Coefficient of variation of click intervals (std/mean, /0 guarded).
				mean := stat.Mean(intervals, nil)
				if mean == 0 {
					features.ClickIntervalCV = 0
				} else {
					std := math.Sqrt(stat.Variance(intervals, nil))
					features.ClickIntervalCV = std / mean
				}
			case "swipe":
				// Average swipe interval (truncated to int).
				mean := stat.Mean(intervals, nil)
				features.AvgSwipeInterval = int(mean)
			}
		}
	case "input":
		inputTS := timing.InputTimestamps
		features.InputCount = len(inputTS)
		// Correct rate in percent (/0 guarded).
		if len(inputTS) > 0 {
			features.CorrectRate = (float64(timing.CorrectCount) / float64(len(inputTS))) * 100
		} else {
			features.CorrectRate = 0
		}
		// Average input interval (only computable with >=2 keystrokes).
		if len(inputTS) > 1 {
			intervals := make([]float64, len(inputTS)-1)
			for i := 0; i < len(inputTS)-1; i++ {
				intervals[i] = float64(inputTS[i+1] - inputTS[i])
			}
			mean := stat.Mean(intervals, nil)
			features.AvgInputInterval = int(mean)
		}
	}
	// 4. Auto-label the sample.
	features.Label, features.TempScore = autoLabel(features, &timing)
	return features, nil
}
// Shared crypto parameters (must stay byte-identical to the frontend).
// NOTE(review): a hardcoded key/IV shipped inside the browser bundle provides
// obfuscation only, not confidentiality — anyone with the frontend JS can
// decrypt this traffic.
const (
	appKey = "53D8DBC5DIK3436A" // AES-128 key (16 UTF-8 bytes)
	appIv  = "KI5JL2SKE9883365" // AES-CBC IV (independent value, no longer sliced from the key)
)
// pKCS7Unpad strips standard PKCS#7 padding (the counterpart of CryptoJS's
// Pkcs7). The AES block size is fixed at 16 bytes; a padding of n bytes
// consists of n trailing bytes each holding the value n.
func pKCS7Unpad(data []byte) ([]byte, error) {
	length := len(data)
	if length == 0 {
		return nil, errors.New("填充数据为空")
	}
	// The last byte encodes the padding length.
	unpadLen := int(data[length-1])
	// The padding length must be 1..BlockSize AND must not exceed the data
	// length. BUG FIX: without the `unpadLen > length` check, the verification
	// loop below indexed past the start of the slice (data[-1]) and panicked
	// on malformed input shorter than its claimed padding.
	if unpadLen < 1 || unpadLen > aes.BlockSize || unpadLen > length {
		return nil, errors.New("无效的填充长度")
	}
	// Every padding byte must equal the padding length (PKCS#7 rule).
	for i := 0; i < unpadLen; i++ {
		if data[length-1-i] != byte(unpadLen) {
			return nil, errors.New("填充数据不合法")
		}
	}
	// Drop the padding, return the original data.
	return data[:length-unpadLen], nil
}
// CaptchaAESDecrypt decrypts the hex-encoded AES-128-CBC ciphertext produced
// by the frontend (CryptoJS, PKCS7 padding, IV taken from appIv).
// hexCiphertext is the frontend "ciphertext" field (hex-encoded string).
// It returns the decrypted plaintext bytes, or an error describing which
// decryption stage failed.
func CaptchaAESDecrypt(hexCiphertext string) ([]byte, error) {
	// Hex string -> raw ciphertext bytes (mirrors toString(CryptoJS.enc.Hex)).
	rawCipher, err := hex.DecodeString(hexCiphertext)
	if err != nil {
		return nil, fmt.Errorf("Hex解码失败: %w", err)
	}
	// Key and IV are the UTF-8 bytes of the shared constants, exactly like
	// the frontend's CryptoJS.enc.Utf8.parse(appKey/appIv).
	var (
		key = []byte(appKey)
		iv  = []byte(appIv)
	)
	// CBC requires the IV to be exactly one AES block (16 bytes).
	if len(iv) != aes.BlockSize {
		return nil, fmt.Errorf("IV长度非法,要求16字节,实际%d字节", len(iv))
	}
	// Build the AES cipher block from the shared key.
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("创建AES块失败: %w", err)
	}
	// CBC ciphertext must be a whole number of AES blocks.
	if len(rawCipher)%aes.BlockSize != 0 {
		return nil, errors.New("密文长度不是AES块大小的整数倍")
	}
	// Decrypt in CBC mode (matches the frontend's CryptoJS.mode.CBC);
	// CryptBlocks writes directly into the destination slice.
	plain := make([]byte, len(rawCipher))
	cipher.NewCBCDecrypter(block, iv).CryptBlocks(plain, rawCipher)
	// Strip PKCS7 padding (matches the frontend's CryptoJS.pad.Pkcs7).
	unpadded, err := pKCS7Unpad(plain)
	if err != nil {
		return nil, fmt.Errorf("去除填充失败: %w", err)
	}
	return unpadded, nil
}
训练
import pandas as pd
import numpy as np
import warnings
import matplotlib.pyplot as plt
import optuna
import os
import pickle
import json # 新增JSON库
import platform
from typing import Dict, List, Optional, Tuple
import time
import re
# 替换为XGBoost相关库
import xgboost as xgb
from xgboost.callback import TrainingCallback # 导入回调基类
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, roc_auc_score
from imblearn.over_sampling import SMOTE
# 忽略警告
warnings.filterwarnings('ignore')
# ====================== Matplotlib font fix (core) ======================
# Pick a CJK-capable font per OS so Chinese labels render (SimHei is often missing).
def set_matplotlib_chinese_font():
    """Configure matplotlib with a Chinese-capable sans-serif font for the
    current OS; on failure, print a warning and keep the default font."""
    system = platform.system()
    try:
        if system == 'Darwin':  # macOS
            plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'PingFang SC']
        elif system == 'Windows':  # Windows
            plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei']
        elif system == 'Linux':  # Linux
            plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'WenQuanYi Micro Hei']
        plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly with CJK fonts
        print(f"✅ 字体设置完成(系统:{system}),当前字体:{plt.rcParams['font.sans-serif']}")
    except Exception as e:
        print(f"⚠️ 字体设置失败:{e},使用默认字体")
# Apply the font configuration at import time.
set_matplotlib_chinese_font()
# Interactive mode so training plots refresh in real time.
plt.ion()
# ====================== Per-verify-type configuration ======================
# Differentiated hyper-parameters and feature lists for each verify_type.
VERIFY_TYPE_CONFIGS = {
    'swipe': {
        # Swipe verification: leans on time-series behaviour features.
        'model_params': {
            'max_depth': 5,
            'learning_rate': 0.015,
            'subsample': 0.75,
            'colsample_bytree': 0.75,
            'reg_alpha': 0.2,
            'reg_lambda': 0.2,
            'min_child_weight': 4,
            'gamma': 0.01,
            'n_estimators': 250
        },
        'core_features': ['load_to_first_delay', 'total_duration', 'avg_swipe_interval','ip_verify_freq','device_ip_bind_count'],
        'smote_strategy': 0.7,  # SMOTE over-sampling ratio
        'extreme_quantile': (0.02, 0.98)  # quantiles used for extreme-value filtering
    },
    'click': {
        # Click verification: leans on click-behaviour features.
        'model_params': {
            'max_depth': 4,
            'learning_rate': 0.02,
            'subsample': 0.7,
            'colsample_bytree': 0.7,
            'reg_alpha': 0.3,
            'reg_lambda': 0.3,
            'min_child_weight': 6,
            'gamma': 0.03,
            'n_estimators': 200
        },
        'core_features': ['load_to_first_delay', 'total_duration', 'click_interval_cv','ip_verify_freq','device_ip_bind_count'],
        'smote_strategy': 0.8,  # SMOTE over-sampling ratio
        'extreme_quantile': (0.01, 0.99)  # quantiles used for extreme-value filtering
    },
    'default': {
        # Fallback configuration for any other verify_type.
        'model_params': {
            'max_depth': 4,
            'learning_rate': 0.02,
            'subsample': 0.7,
            'colsample_bytree': 0.7,
            'reg_alpha': 0.3,
            'reg_lambda': 0.3,
            'min_child_weight': 5,
            'gamma': 0.02,
            'n_estimators': 200
        },
        'core_features': ['load_to_first_delay', 'total_duration', 'ip_verify_freq','device_ip_bind_count'],
        'smote_strategy': 0.8,
        'extreme_quantile': (0.01, 0.99)
    }
}
# Columns that must never reach the model (string IDs, IPs, timestamps, PII).
USELESS_COLS = [
    'id', 'sample_id', 'session_id', 'ip', 'device_fingerprint',
    'temp_score', 'created_at', 'updated_at', 'user_id', 'device_id',
    'phone_number', 'email', 'cookie_id', 'session_token'
]
# Global store of per-verify_type training-progress data (losses, AUCs, plot handles).
train_metrics = {
    'swipe': {
        'train_loss': [], 'val_loss': [], 'train_auc': [], 'val_auc': [],
        'fig': None, 'ax1': None, 'ax2': None, 'ax3': None, 'ax4': None,
        'lines': {}, 'core_feature_importance': {}, 'importance_iterations': []
    },
    'click': {
        'train_loss': [], 'val_loss': [], 'train_auc': [], 'val_auc': [],
        'fig': None, 'ax1': None, 'ax2': None, 'ax3': None, 'ax4': None,
        'lines': {}, 'core_feature_importance': {}, 'importance_iterations': []
    }
}
# Global store of per-verify_type preprocessing parameters (quantiles, medians, ...).
preprocessing_params = {
    'swipe': {},
    'click': {},
    'default': {}
}
# Pre-create the feature-importance history lists for each type's core features.
for vt in ['swipe', 'click']:
    config = VERIFY_TYPE_CONFIGS.get(vt, VERIFY_TYPE_CONFIGS['default'])
    core_features = config['core_features']
    for feat in core_features:
        train_metrics[vt]['core_feature_importance'][feat] = []
# ====================== XGBoost model -> JSON helper ======================
def convert_xgb_model_to_json(model: xgb.Booster, output_path: str):
    """
    Convert an XGBoost model into a standard JSON file (parseable from Go).
    Args:
        model: trained XGBoost Booster
        output_path: path of the JSON file to write
    Returns:
        True on success, False on failure (the error is printed, not raised).
    """
    try:
        # 1. Dump the model in XGBoost's native JSON format
        #    (save_raw returns bytes; json.loads accepts them directly).
        model_json_str = model.save_raw("json")
        # 2. Parse into a Python dict to guarantee well-formed JSON.
        model_json = json.loads(model_json_str)
        # 3. Write pretty-printed JSON; keep non-ASCII characters readable.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(model_json, f, indent=4, ensure_ascii=False)
        print(f"✅ XGBoost模型已转换为JSON格式:{output_path}")
        return True
    except Exception as e:
        print(f"❌ XGBoost模型转JSON失败:{e}")
        return False
# ====================== Utility functions ======================
def is_ip_address(s):
    """Return True if *s* is a valid dotted-quad IPv4 address string.

    Non-string inputs return False. Each octet must be in 0-255, so strings
    like "999.1.2.3" are rejected (the previous shape-only regex accepted them).
    """
    if not isinstance(s, str):
        return False
    if re.match(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$', s) is None:
        return False
    # BUG FIX: also validate the octet value range, not just the shape.
    return all(0 <= int(octet) <= 255 for octet in s.split('.'))
def is_numeric_column(series):
    """Return True when *series* can safely be coerced to numbers.

    Empty (all-NaN) columns and columns containing IPv4-looking strings are
    treated as non-numeric.
    """
    values = series.dropna()
    if values.empty:
        return False
    # Columns holding IP addresses must never be treated as numeric.
    looks_like_ip = values.apply(
        lambda v: is_ip_address(v) if isinstance(v, str) else False
    )
    if looks_like_ip.any():
        return False
    # Probe the conversion; any failure means the column is not numeric.
    try:
        pd.to_numeric(values, errors='raise')
    except (ValueError, TypeError):
        return False
    return True
def get_verify_type_config(verify_type: str) -> dict:
    """Return the config dict for *verify_type*, falling back to 'default'."""
    return VERIFY_TYPE_CONFIGS.get(verify_type, VERIFY_TYPE_CONFIGS['default'])
def init_realtime_plot(verify_type: str):
    """Create (once) the 2x2 live-monitoring figure for *verify_type* and
    stash the figure, axes, and line handles in the train_metrics global.
    Idempotent: does nothing if the figure already exists."""
    if train_metrics[verify_type]['fig'] is None:
        # Core features for this verify type drive subplot 3's legend.
        config = get_verify_type_config(verify_type)
        core_features = config['core_features']
        # Figure layout.
        fig = plt.figure(figsize=(15, 10))
        fig.suptitle(f'模型训练监控 - Verify Type: {verify_type.upper()}', fontsize=16)
        # Subplot 1: LogLoss curves (lower is better).
        ax1 = plt.subplot(2, 2, 1)
        ax1.set_xlabel('迭代次数')
        ax1.set_ylabel('LogLoss')
        ax1.set_title('训练/验证集LogLoss(越小越好)')
        ax1.grid(True, alpha=0.3)
        line1, = ax1.plot([], [], label='训练集', color='blue', linewidth=1.5)
        line2, = ax1.plot([], [], label='验证集', color='red', linewidth=1.5)
        ax1.legend()
        # Subplot 2: AUC curves (higher is better).
        ax2 = plt.subplot(2, 2, 2)
        ax2.set_xlabel('迭代次数')
        ax2.set_ylabel('AUC')
        ax2.set_title('训练/验证集AUC(越大越好)')
        ax2.grid(True, alpha=0.3)
        line3, = ax2.plot([], [], label='训练集', color='blue', linewidth=1.5)
        line4, = ax2.plot([], [], label='验证集', color='red', linewidth=1.5)
        ax2.legend()
        # Subplot 3: core-feature importance (gain) over iterations.
        ax3 = plt.subplot(2, 2, 3)
        ax3.set_xlabel('迭代次数')
        ax3.set_ylabel('特征重要性(Gain)')
        ax3.set_title(f'核心特征重要性趋势({"/".join(core_features)})')
        ax3.grid(True, alpha=0.3)
        # One line per core feature (colors cycle if there are more than four).
        lines = {}
        colors = ['green', 'orange', 'purple', 'brown']
        for i, feat in enumerate(core_features):
            line, = ax3.plot([], [], label=feat, color=colors[i % len(colors)], linewidth=1.5)
            lines[feat] = line
        ax3.legend()
        # Subplot 4: reserved for feature-distribution comparison; hidden for now.
        ax4 = plt.subplot(2, 2, 4)
        ax4.set_title('核心特征分布(正常vs异常)')
        ax4.set_visible(False)
        # Persist all handles in the global store for update_realtime_plot.
        train_metrics[verify_type]['fig'] = fig
        train_metrics[verify_type]['ax1'] = ax1
        train_metrics[verify_type]['ax2'] = ax2
        train_metrics[verify_type]['ax3'] = ax3
        train_metrics[verify_type]['ax4'] = ax4
        train_metrics[verify_type]['lines']['train_loss'] = line1
        train_metrics[verify_type]['lines']['val_loss'] = line2
        train_metrics[verify_type]['lines']['train_auc'] = line3
        train_metrics[verify_type]['lines']['val_auc'] = line4
        for feat, line in lines.items():
            train_metrics[verify_type]['lines'][feat] = line
def update_realtime_plot(verify_type: str, iteration, train_loss, val_loss, train_auc, val_auc, model=None):
    """Append one iteration's metrics to *verify_type*'s live plot and redraw.

    If *model* is given, feature importances are sampled every 10th iteration
    and plotted on subplot 3. Plotting failures are caught and printed — they
    must never abort training."""
    metrics = train_metrics[verify_type]
    config = get_verify_type_config(verify_type)
    core_features = config['core_features']
    # Record the scalar metrics.
    metrics['train_loss'].append(train_loss)
    metrics['val_loss'].append(val_loss)
    metrics['train_auc'].append(train_auc)
    metrics['val_auc'].append(val_auc)
    # Refresh the four base curves (length checks guard against drift).
    x_data = list(range(1, len(metrics['train_loss']) + 1))
    if len(x_data) == len(metrics['train_loss']):
        metrics['lines']['train_loss'].set_data(x_data, metrics['train_loss'])
    if len(x_data) == len(metrics['val_loss']):
        metrics['lines']['val_loss'].set_data(x_data, metrics['val_loss'])
    if len(x_data) == len(metrics['train_auc']):
        metrics['lines']['train_auc'].set_data(x_data, metrics['train_auc'])
    if len(x_data) == len(metrics['val_auc']):
        metrics['lines']['val_auc'].set_data(x_data, metrics['val_auc'])
    # Sample feature importances every 10 iterations (gain-based).
    if model and iteration % 10 == 0:
        try:
            importance_dict = model.get_score(importance_type='gain')
            metrics['importance_iterations'].append(iteration)
            for feat in core_features:
                imp_value = importance_dict.get(feat, 0)
                metrics['core_feature_importance'][feat].append(imp_value)
            # Redraw the importance curves.
            imp_iter = metrics['importance_iterations']
            for feat in core_features:
                imp_vals = metrics['core_feature_importance'][feat]
                if len(imp_iter) == len(imp_vals) and len(imp_iter) > 0:
                    metrics['lines'][feat].set_data(imp_iter, imp_vals)
            metrics['ax3'].relim()
            metrics['ax3'].autoscale_view()
        except Exception as e:
            print(f"\n⚠️ {verify_type} 特征重要性绘图更新失败:{e}")
    # Rescale the loss/AUC axes to the new data.
    metrics['ax1'].relim()
    metrics['ax1'].autoscale_view()
    metrics['ax2'].relim()
    metrics['ax2'].autoscale_view()
    # Flush the canvas; a brief sleep lets the GUI event loop breathe.
    try:
        metrics['fig'].canvas.draw()
        metrics['fig'].canvas.flush_events()
    except Exception as e:
        print(f"\n⚠️ {verify_type} 绘图刷新失败:{e}")
    time.sleep(0.01)
# ====================== Per-type training callback ======================
class XGBoostRealtimePlotCallback(TrainingCallback):
    """XGBoost callback that streams per-iteration metrics to the live plot
    of a specific verify_type."""

    def __init__(self, verify_type: str):
        super().__init__()
        self.iteration = 0
        self.verify_type = verify_type

    def after_iteration(self, model, epoch, evals_log):
        """Run after every boosting round; returning False keeps training going."""
        self.iteration = epoch
        try:
            # Pull the latest train/eval logloss and AUC out of the eval log.
            train_loss = val_loss = train_auc = val_auc = None
            if 'train' in evals_log and 'logloss' in evals_log['train']:
                train_loss = evals_log['train']['logloss'][-1]
            if 'eval' in evals_log and 'logloss' in evals_log['eval']:
                val_loss = evals_log['eval']['logloss'][-1]
            if 'train' in evals_log and 'auc' in evals_log['train']:
                train_auc = evals_log['train']['auc'][-1]
            if 'eval' in evals_log and 'auc' in evals_log['eval']:
                val_auc = evals_log['eval']['auc'][-1]
            # BUG FIX: use explicit None checks — the previous
            # all([train_loss, ...]) treated a metric value of exactly 0.0
            # as "missing" and silently skipped the plot update.
            if all(v is not None for v in (train_loss, val_loss, train_auc, val_auc)):
                update_realtime_plot(self.verify_type, self.iteration,
                                     train_loss, val_loss, train_auc, val_auc, model)
        except Exception as e:
            print(f"\n⚠️ {self.verify_type} 回调函数执行失败(迭代{epoch}):{e}")
        return False
# ====================== Data processing (fixed version) ======================
def load_and_validate_data(csv_file: str) -> pd.DataFrame:
    """Load the captcha-behaviour CSV and run basic validation.

    Steps: read CSV -> check required columns -> keep only swipe/click rows
    -> drop useless/ID-like columns.

    Raises:
        Exception: wrapping (and chaining) whatever failed underneath.
    """
    try:
        df = pd.read_csv(csv_file, encoding='utf-8')
        print(f"✅ 数据加载成功,形状: {df.shape}")
        # BUG FIX: validate required columns BEFORE touching them, so a
        # missing 'verify_type' produces a clear message instead of a raw
        # KeyError from value_counts().
        required_cols = [
            'label', 'verify_type', 'load_to_first_delay', 'total_duration'
        ]
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"❌ 缺少关键列:{missing_cols}")
        print(f"✅ verify_type分布:\n{df['verify_type'].value_counts()}")
        # Keep only the verify types we actually train models for.
        valid_verify_types = ['swipe', 'click']
        df = df[df['verify_type'].isin(valid_verify_types)]
        print(f"✅ 过滤后有效数据量:{len(df)}(仅保留swipe/click)")
        # Drop ID/IP/time columns that must not leak into training.
        df = filter_useless_columns(df)
        return df
    except Exception as e:
        # Chain the original exception so the real traceback is preserved.
        raise Exception(f"❌ 加载数据失败:{str(e)}") from e
def filter_useless_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Drop identifier/IP/time columns that must not leak into training."""
    result = df.copy()
    # First pass: drop the statically known useless columns that are present.
    cols_to_drop = [col for col in USELESS_COLS if col in result.columns]
    if cols_to_drop:
        result = result.drop(columns=cols_to_drop)
        print(f"\n🗑️ 已删除无用列:{cols_to_drop}")
    # Second pass: hunt for columns whose name mentions "ip" (except the
    # whitelisted feature columns) and whose contents look like IPv4 addresses.
    whitelisted = ('ip_type', 'ip_verify_freq', 'device_ip_bind_count')
    ip_columns = []
    for col in result.columns:
        if 'ip' not in col.lower() or col in whitelisted:
            continue
        # Sample up to 10 non-null values and test them for IP shape.
        sample_data = result[col].dropna().head(10)
        contains_ip = sample_data.apply(
            lambda x: is_ip_address(x) if isinstance(x, str) else False
        ).any()
        if contains_ip:
            ip_columns.append(col)
    if ip_columns:
        result = result.drop(columns=ip_columns)
        print(f"🗑️ 已删除包含IP地址的列:{ip_columns}")
    return result
def filter_extreme_samples(df: pd.DataFrame, verify_type: str, cols: list) -> pd.DataFrame:
    """Filter extreme-valued samples per verify_type using its configured
    quantile band, and record the preprocessing statistics (quantiles, mean,
    std, median) into the preprocessing_params global for later use (e.g. the
    Go side). Non-existent or non-numeric columns are skipped with a warning."""
    config = get_verify_type_config(verify_type)
    lower_q, upper_q = config['extreme_quantile']
    df_filtered = df.copy()
    total_samples = len(df_filtered)
    # Reset this type's preprocessing parameters before collecting new ones.
    preprocessing_params[verify_type] = {}
    for col in cols:
        if col not in df_filtered.columns:
            print(f"⚠️ {verify_type} 列{col}不存在,跳过过滤")
            continue
        # Only numeric columns can be quantile-filtered.
        if not is_numeric_column(df_filtered[col]):
            print(f"⚠️ {verify_type} 列{col}不是数值类型,跳过过滤")
            continue
        q_low = df_filtered[col].quantile(lower_q)
        q_high = df_filtered[col].quantile(upper_q)
        # Persist the statistics for downstream (serving-side) preprocessing.
        # NOTE: computed on the progressively filtered frame, so later columns'
        # stats reflect earlier columns' filtering.
        preprocessing_params[verify_type][col] = {
            'q_low': q_low, 'q_high': q_high,
            'mean': df_filtered[col].mean(), 'std': df_filtered[col].std(),
            'median': df_filtered[col].median()  # median kept for fill-value use
        }
        # Keep only rows inside the quantile band.
        filter_condition = (df_filtered[col] >= q_low) & (df_filtered[col] <= q_high)
        df_filtered = df_filtered[filter_condition]
        filtered_count = total_samples - len(df_filtered)
        total_samples = len(df_filtered)
        print(f"\n🔧 [{verify_type}] {col} 极端值过滤:")
        print(f" 分位数范围:{q_low:.2f} ~ {q_high:.2f}")
        print(f" 过滤样本数:{filtered_count},剩余样本数:{total_samples}")
    return df_filtered
def process_labels(df: pd.DataFrame) -> pd.DataFrame:
    """Add a binary_label column: 'normal' -> 0, anything else -> 1."""
    print(f"\n📊 原始标签分布:")
    print(df['label'].value_counts())
    # Unknown label strings default to 1 (treated as risky).
    mapping = {'normal': 0, 'risk': 1, 'abnormal': 1}
    df['binary_label'] = df['label'].map(lambda value: mapping.get(value, 1))
    print(f"\n📊 转换后的二进制标签分布:")
    print(df['binary_label'].value_counts())
    return df
def select_features(df: pd.DataFrame, verify_type: str) -> Tuple[List[str], List[str], List[str]]:
    """Partition the frame's columns into feature lists for *verify_type*.

    Returns (feature_cols, categorical_cols, numeric_cols). Label columns and
    verify_type itself are excluded; the remaining columns are auto-classified
    as numeric vs categorical via is_numeric_column. Missing core features are
    reported but not fatal."""
    config = get_verify_type_config(verify_type)
    core_features = config['core_features']
    # Exclude target columns and the type discriminator.
    feature_cols = [col for col in df.columns if col not in ['label', 'binary_label', 'verify_type']]
    # Auto-classify each remaining column.
    categorical_cols = []
    numeric_cols = []
    for col in feature_cols:
        # Numeric if the column survives pd.to_numeric (and holds no IPs).
        if is_numeric_column(df[col]):
            numeric_cols.append(col)
        else:
            categorical_cols.append(col)
    print(f"\n🔧 [{verify_type}] 自动识别特征类型:")
    print(f" 分类特征: {categorical_cols}")
    # Star-mark the configured core features in the printed list.
    marked_numeric = [f"⭐{feat}" if feat in core_features else feat for feat in numeric_cols]
    print(f" 数值特征: {marked_numeric}")
    # Warn (only) if any core feature did not classify as numeric.
    missing_core = [feat for feat in core_features if feat not in numeric_cols]
    if missing_core:
        print(f"⚠️ [{verify_type}] 核心特征缺失:{missing_core}")
    return feature_cols, categorical_cols, numeric_cols
def fix_numeric_outliers(df: pd.DataFrame, verify_type: str, numeric_cols: list) -> pd.DataFrame:
    """Repair anomalous numeric values, differentiated by verify type.

    Core features only get negative values clamped to 0; every other numeric
    feature is winsorized into the 1.5*IQR fence around [Q1, Q3].
    """
    core_features = get_verify_type_config(verify_type)['core_features']
    repaired = df.copy()
    # Core features: clamp negatives to 0
    for col in core_features:
        if col not in numeric_cols or col not in repaired.columns:
            continue
        # Coerce to numeric; unparseable entries become NaN
        repaired[col] = pd.to_numeric(repaired[col], errors='coerce')
        negative = repaired[col] < 0
        neg_total = int(negative.sum())
        if neg_total > 0:
            repaired.loc[negative, col] = 0
            print(f"\n🔧 [{verify_type}] 修复{col}负数:{neg_total} 个值")
    # Remaining numeric features: winsorize to the IQR fence
    for col in (c for c in numeric_cols if c not in core_features and c in repaired.columns):
        repaired[col] = pd.to_numeric(repaired[col], errors='coerce')
        q1 = repaired[col].quantile(0.25)
        q3 = repaired[col].quantile(0.75)
        spread = q3 - q1
        low = q1 - 1.5 * spread
        high = q3 + 1.5 * spread
        outliers = (repaired[col] < low) | (repaired[col] > high)
        n_out = int(outliers.sum())
        if n_out > 0:
            repaired.loc[outliers, col] = repaired[col].clip(low, high)
            print(f"🔧 [{verify_type}] 修复{col}异常值:{n_out} 个值")
    return repaired
def handle_missing_values(df: pd.DataFrame, verify_type: str, numeric_cols: list,
                          categorical_cols: list) -> pd.DataFrame:
    """Fill missing values, differentiated by verify type (fixed version).

    Core numeric features are filled with their median (delay/duration
    columns) or 0 (freq columns); other numeric features use their median;
    categorical features use their mode. The medians/modes computed here are
    also written into the module-level ``preprocessing_params`` store so the
    Go inference side can replay identical fills at prediction time.
    """
    config = get_verify_type_config(verify_type)
    core_features = config['core_features']
    df_filled = df.copy()
    # Medians of numeric features (exported for the Go-side fill logic)
    numeric_medians = {}
    # Modes of categorical features (exported for the Go-side fill logic)
    categorical_modes = {}
    # Fill core features first
    for col in core_features:
        if col in numeric_cols and col in df_filled.columns:
            # Coerce to numeric; unparseable values become NaN
            df_filled[col] = pd.to_numeric(df_filled[col], errors='coerce')
            numeric_medians[col] = df_filled[col].median()
            missing_count = df_filled[col].isnull().sum()
            if missing_count > 0:
                if 'delay' in col or 'duration' in col:
                    # Timing features: median fill
                    fill_val = df_filled[col].median()
                    df_filled[col] = df_filled[col].fillna(fill_val)
                    print(f"\n🔧 [{verify_type}] 填充{col}缺失值:{missing_count} 个值,中位数={fill_val:.2f}")
                else:
                    # Frequency-like features default to 0; everything else to the median
                    fill_val = 0 if 'freq' in col else df_filled[col].median()
                    df_filled[col] = df_filled[col].fillna(fill_val)
                    print(f"\n🔧 [{verify_type}] 填充{col}缺失值:{missing_count} 个值,默认值={fill_val:.2f}")
    # Fill the remaining (non-core) numeric features with their median
    for col in [c for c in numeric_cols if c not in core_features and c in df_filled.columns]:
        df_filled[col] = pd.to_numeric(df_filled[col], errors='coerce')
        numeric_medians[col] = df_filled[col].median()
        missing_count = df_filled[col].isnull().sum()
        if missing_count > 0:
            fill_val = df_filled[col].median()
            df_filled[col] = df_filled[col].fillna(fill_val)
    # Fill categorical features with their mode
    for col in categorical_cols:
        if col in df_filled.columns:
            if df_filled[col].isnull().sum() > 0:
                # NOTE(review): mode()[0] raises IndexError on an all-NaN column —
                # confirm upstream guarantees at least one non-null value.
                mode_val = df_filled[col].mode()[0]
                categorical_modes[col] = mode_val
                df_filled[col] = df_filled[col].fillna(mode_val)
            else:
                # No missing values: export a sentinel so the Go side still has an entry
                categorical_modes[col] = 'Unknown'
    # Persist medians and modes into the shared preprocessing parameter store
    preprocessing_params[verify_type]['numeric_medians'] = numeric_medians
    preprocessing_params[verify_type]['categorical_modes'] = categorical_modes
    return df_filled
def encode_categorical_features(df: pd.DataFrame, categorical_cols: list) -> Tuple[
    pd.DataFrame, Dict[str, LabelEncoder]]:
    """Label-encode categorical feature columns.

    Missing values become the sentinel string 'Unknown'. The encoder is fit
    on the observed values plus 'Unknown' so the sentinel ALWAYS has a code:
    the prediction paths (Python `le_mapping.get('Unknown', 0)` and the Go
    `le.Mapping["Unknown"]` fallback) rely on it existing.

    Bug fixed: the original "ensure 'Unknown' is a category" branch was dead
    code — after `astype(str)` the dtype is never 'category', so the ternary
    was a no-op and 'Unknown' could be absent from the fitted classes.

    Returns (encoded_df, {column: fitted LabelEncoder}).
    """
    label_encoders = {}
    df_encoded = df.copy()
    for col in categorical_cols:
        if col not in df_encoded.columns:
            continue
        # Normalize to strings; missing values become the 'Unknown' sentinel
        values = df_encoded[col].fillna('Unknown').astype(str)
        le = LabelEncoder()
        # Fit on observed values + 'Unknown' (LabelEncoder sorts classes, so
        # when 'Unknown' already occurs the result is identical to fit_transform)
        le.fit(list(values.unique()) + ['Unknown'])
        df_encoded[col] = le.transform(values)
        label_encoders[col] = le
        print(f"\n🔤 编码 {col}:原始值={le.classes_[:10]}...(显示前10个)")
    return df_encoded, label_encoders
def balance_samples(X: pd.DataFrame, y: pd.Series, verify_type: str) -> Tuple[pd.DataFrame, pd.Series]:
    """Oversample the minority class with SMOTE, per verify-type strategy.

    Skips resampling when only one class is present or when the positive
    ratio already meets the configured strategy; on SMOTE failure the
    original data is returned unchanged.
    """
    smote_strategy = get_verify_type_config(verify_type)['smote_strategy']
    if y.nunique() < 2:
        print(f"\n⚠️ [{verify_type}] 标签只有1类,跳过SMOTE采样")
        return X, y
    pos_ratio = y.sum() / len(y)
    print(f"\n📊 [{verify_type}] 原始样本分布:异常样本占比 {pos_ratio:.2%}")
    if pos_ratio >= smote_strategy:
        return X, y
    try:
        sampler = SMOTE(random_state=42, sampling_strategy=smote_strategy)
        X_res, y_res = sampler.fit_resample(X, y)
    except Exception as exc:
        print(f"❌ [{verify_type}] SMOTE采样失败:{exc}")
        return X, y
    new_ratio = y_res.sum() / len(y_res)
    print(f"✅ [{verify_type}] 采样后异常样本占比 {new_ratio:.2%},样本量 {len(X_res)}")
    return X_res, y_res
def create_core_feature_interactions(df: pd.DataFrame, verify_type: str) -> pd.DataFrame:
    """Add verify-type-specific interaction features plus one shared ratio.

    swipe: load_swipe_ratio = load_to_first_delay / (avg_swipe_interval + 1e-6)
    click: duration_click_cv = total_duration * click_interval_cv
    both:  load_to_total_ratio = load_to_first_delay / (total_duration + 1e-6)
    """
    config = get_verify_type_config(verify_type)
    core_features = config['core_features']
    result = df.copy()
    if verify_type == 'swipe':
        # Swipe: ratio of initial delay to the average swipe interval
        if 'load_to_first_delay' in result.columns and 'avg_swipe_interval' in result.columns:
            result['load_swipe_ratio'] = result['load_to_first_delay'] / (
                result['avg_swipe_interval'] + 1e-6)
    elif verify_type == 'click':
        # Click: total duration weighted by the click-interval variation
        if 'total_duration' in result.columns and 'click_interval_cv' in result.columns:
            result['duration_click_cv'] = result['total_duration'] * result['click_interval_cv']
    # Shared interaction for every verify type
    if 'load_to_first_delay' in result.columns and 'total_duration' in result.columns:
        result['load_to_total_ratio'] = result['load_to_first_delay'] / (result['total_duration'] + 1e-6)
    print(f"\n🔧 [{verify_type}] 创建核心特征交互项完成")
    return result
# ====================== 模型训练函数(差异化版本) ======================
def tune_hyperparameters(X_train, y_train, X_val, y_val, verify_type: str):
    """Run a short Optuna search around this verify type's base parameters.

    Each trial samples within a narrow band of the configured defaults
    (max_depth +/-1, learning_rate and reg_* +/-20%, subsample/colsample
    +/-0.05) and trains up to 100 rounds with early stopping on the
    validation set. Returns the best parameter dict found.
    """
    config = get_verify_type_config(verify_type)
    base_params = config['model_params']
    def objective(trial):
        # Candidate parameters sampled close to the configured baseline
        params = {
            'booster': 'gbtree',
            'objective': 'binary:logistic',
            'eval_metric': ['logloss', 'auc'],
            'max_depth': trial.suggest_int('max_depth', base_params['max_depth'] - 1, base_params['max_depth'] + 1),
            'learning_rate': trial.suggest_float('learning_rate', base_params['learning_rate'] * 0.8,
                                                 base_params['learning_rate'] * 1.2, log=True),
            'subsample': trial.suggest_float('subsample', base_params['subsample'] - 0.05,
                                             base_params['subsample'] + 0.05),
            'colsample_bytree': trial.suggest_float('colsample_bytree', base_params['colsample_bytree'] - 0.05,
                                                    base_params['colsample_bytree'] + 0.05),
            'reg_alpha': trial.suggest_float('reg_alpha', base_params['reg_alpha'] * 0.8,
                                             base_params['reg_alpha'] * 1.2),
            'reg_lambda': trial.suggest_float('reg_lambda', base_params['reg_lambda'] * 0.8,
                                              base_params['reg_lambda'] * 1.2),
            'random_state': 42,
            'verbosity': 0
        }
        dtrain = xgb.DMatrix(X_train, label=y_train)
        dval = xgb.DMatrix(X_val, label=y_val)
        model = xgb.train(
            params, dtrain, num_boost_round=100,
            evals=[(dval, 'eval')], early_stopping_rounds=15, verbose_eval=False
        )
        # best_score comes from early stopping — presumably the last listed
        # eval metric (auc), which matches direction='maximize'; confirm.
        return model.best_score
    study = optuna.create_study(direction='maximize', study_name=f'{verify_type}_xgb_tune')
    study.optimize(objective, n_trials=5)
    print(f"\n🏆 [{verify_type}] 最优参数:{study.best_params}")
    print(f"🏆 [{verify_type}] 最优验证AUC:{study.best_value:.4f}")
    return study.best_params
def train_xgboost_model(X_train, y_train, X_test, y_test, verify_type: str, use_optuna: bool = True):
    """Train an XGBoost model for one verify type.

    Sets up live plotting, merges the type-specific model params (optionally
    refined by Optuna) over the base params, trains with early stopping
    (falling back to a callback-free retry on failure), then draws
    core-feature boxplots for the test split.
    Returns (model, evals_result, params_used).
    """
    # Set up the realtime training plot for this verify type
    init_realtime_plot(verify_type)
    # Per-verify-type configuration
    config = get_verify_type_config(verify_type)
    model_params = config['model_params']
    # Base parameters; type-specific params override/extend these
    base_params = {
        'booster': 'gbtree',
        'objective': 'binary:logistic',
        'eval_metric': ['logloss', 'auc'],
        'verbosity': 1,
        'random_state': 42,
        'scale_pos_weight': 1,
        'importance_type': 'gain',
        **model_params
    }
    # Optional hyperparameter tuning (needs both classes present)
    if use_optuna and y_train.nunique() >= 2:
        X_train_tune, X_val_tune, y_train_tune, y_val_tune = train_test_split(
            X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
        )
        best_params = tune_hyperparameters(X_train_tune, y_train_tune, X_val_tune, y_val_tune, verify_type)
        base_params = {**base_params, **best_params}
    # Build DMatrix inputs
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dtest = xgb.DMatrix(X_test, label=y_test)
    # Record the feature column order for inference-time alignment
    preprocessing_params[verify_type]['feature_cols'] = X_train.columns.tolist()
    # Realtime plotting callback
    plot_callback = XGBoostRealtimePlotCallback(verify_type)
    # Train
    print(f"\n🚀 开始训练 {verify_type} 类型验证码模型...")
    evals_result = {}
    try:
        model = xgb.train(
            base_params,
            dtrain,
            num_boost_round=base_params['n_estimators'],
            evals=[(dtrain, 'train'), (dtest, 'eval')],
            evals_result=evals_result,
            early_stopping_rounds=20,
            verbose_eval=10,
            callbacks=[plot_callback]
        )
    except Exception as e:
        print(f"\n⚠️ [{verify_type}] 训练异常:{e}")
        # Retry without the plotting callback
        model = xgb.train(
            base_params,
            dtrain,
            num_boost_round=base_params['n_estimators'],
            evals=[(dtrain, 'train'), (dtest, 'eval')],
            evals_result=evals_result,
            early_stopping_rounds=20,
            verbose_eval=10
        )
    # Draw core-feature distribution boxplots (normal vs abnormal)
    try:
        if y_test.nunique() >= 2:
            ax4 = train_metrics[verify_type]['ax4']
            ax4.clear()
            ax4.set_title(f'{verify_type} 核心特征分布(正常vs异常)')
            X_test_df = pd.DataFrame(X_test, columns=X_train.columns)
            normal_mask = y_test == 0
            abnormal_mask = y_test == 1
            core_features = config['core_features']
            positions = list(range(1, len(core_features) + 1))
            has_data = False
            for i, feat in enumerate(core_features):
                if feat in X_test_df.columns:
                    normal_data = X_test_df.loc[normal_mask, feat].values
                    abnormal_data = X_test_df.loc[abnormal_mask, feat].values
                    if len(normal_data) > 0 and len(abnormal_data) > 0:
                        has_data = True
                        # Side-by-side boxplots: normal (green, left) vs abnormal (red, right)
                        ax4.boxplot(normal_data, positions=[positions[i] - 0.15], widths=0.2,
                                    patch_artist=True, boxprops=dict(facecolor='green', alpha=0.5),
                                    label='正常' if i == 0 else "")
                        ax4.boxplot(abnormal_data, positions=[positions[i] + 0.15], widths=0.2,
                                    patch_artist=True, boxprops=dict(facecolor='red', alpha=0.5),
                                    label='异常' if i == 0 else "")
            if has_data:
                ax4.set_xticks(positions)
                ax4.set_xticklabels(core_features)
                ax4.legend()
                ax4.grid(True, alpha=0.3)
                ax4.set_visible(True)
    except Exception as e:
        # Plotting is best-effort; never fail training because of it
        print(f"\n⚠️ [{verify_type}] 特征分布绘图失败:{e}")
    return model, evals_result, base_params
def evaluate_model(model, X_test, y_test, feature_cols, verify_type: str):
    """Evaluate a trained booster: accuracy, AUC, and gain-based importance.

    AUC is reported as -1 when the test labels contain a single class.
    Returns a dict with metrics, the importance DataFrame, the core-feature
    contribution percentage, and a plain importance dict for JSON export.
    """
    core_features = get_verify_type_config(verify_type)['core_features']
    probs = model.predict(xgb.DMatrix(X_test))
    preds = (probs > 0.5).astype(int)
    # Headline metrics
    accuracy = accuracy_score(y_test, preds)
    auc = roc_auc_score(y_test, probs) if y_test.nunique() >= 2 else -1
    print(f"\n📈 [{verify_type}] 模型评估结果:")
    print(f"🏆 最佳迭代次数:{model.best_iteration}")
    print(f"🎯 准确率:{accuracy:.4f}")
    print(f"📊 AUC:{auc:.4f}" if auc != -1 else "📊 AUC:无意义")
    # Gain-based feature importance, zero-filled for unused features
    gain = model.get_score(importance_type='gain')
    all_importance = {feat: gain.get(feat, 0) for feat in feature_cols}
    importance_df = pd.DataFrame({
        'feature': list(all_importance.keys()),
        'importance': list(all_importance.values())
    }).sort_values('importance', ascending=False)
    # Star-mark core features
    importance_df['is_core'] = importance_df['feature'].apply(lambda f: '⭐' if f in core_features else '')
    # Share of total gain contributed by the core features
    total = importance_df['importance'].sum()
    core_total = importance_df[importance_df['feature'].isin(core_features)]['importance'].sum()
    core_contribution = (core_total / total * 100) if total > 0 else 0
    print(f"\n🎯 [{verify_type}] 核心特征总贡献:{core_contribution:.1f}%")
    print(f"\n✨ [{verify_type}] 特征重要性(前10):")
    print(importance_df.head(10)[['is_core', 'feature', 'importance']])
    return {
        'accuracy': accuracy, 'auc': auc,
        'feature_importance': importance_df,
        'core_feature_contribution': core_contribution,
        'best_iteration': model.best_iteration,
        'feature_importance_dict': all_importance  # plain dict for JSON export
    }
def save_results(models: dict, label_encoders: dict, feature_cols: dict,
                 eval_results: dict, params: dict, save_dir='./captcha_results'):
    """Persist every trained model and its artifacts (including JSON exports).

    For each verify type present in ``models`` this writes: the binary
    booster, a native-JSON conversion of it, the preprocessing params
    (pickle), a consolidated JSON config (metrics, feature metadata,
    encoder mappings, fill values), and a CSV of the evaluation metrics.
    Finally it writes the global config JSON and the pickled encoders.
    """
    # Create output directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs('./captcha_models', exist_ok=True)
    # Save artifacts per verify type
    for verify_type in ['swipe', 'click']:
        if verify_type in models:
            # Binary booster file
            model_bin_path = f'./captcha_models/captcha_xgb_model_{verify_type}.bin'
            models[verify_type].save_model(model_bin_path)
            print(f"\n💾 {verify_type} 模型已保存到:{model_bin_path}")
            # Native-JSON conversion of the booster (consumed by the Go side)
            model_json_path = f'./captcha_models/captcha_xgb_model_{verify_type}.json'
            convert_xgb_model_to_json(models[verify_type], model_json_path)
            # Preprocessing parameters (pickle format)
            preprocess_path = f'./captcha_models/preprocessing_params_{verify_type}.pkl'
            with open(preprocess_path, 'wb') as f:
                pickle.dump(preprocessing_params[verify_type], f)
            # Consolidated JSON config file
            # 1. Assemble the JSON-serializable payload
            json_data = {
                'verify_type': verify_type,
                'model_config': params.get(verify_type, {}),
                'preprocessing_params': preprocessing_params.get(verify_type, {}),
                'evaluation_metrics': {
                    'accuracy': float(eval_results[verify_type]['accuracy']),
                    'auc': float(eval_results[verify_type]['auc']),
                    'core_feature_contribution': float(eval_results[verify_type]['core_feature_contribution']),
                    'best_iteration': int(eval_results[verify_type]['best_iteration'])
                },
                'feature_importance': eval_results[verify_type]['feature_importance_dict'],
                'core_features': get_verify_type_config(verify_type)['core_features'],
                'feature_columns': feature_cols.get(verify_type, []),
                'training_timestamp': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                'numeric_medians': preprocessing_params[verify_type].get('numeric_medians', {}),
                'categorical_modes': preprocessing_params[verify_type].get('categorical_modes', {})
            }
            # 2. Flatten each LabelEncoder into its class list + value->code mapping
            le_json = {}
            if verify_type in label_encoders:
                for col, le in label_encoders[verify_type].items():
                    le_json[col] = {
                        'classes': le.classes_.tolist(),
                        'mapping': {str(cls): int(idx) for idx, cls in enumerate(le.classes_)},
                        'num_classes': len(le.classes_)
                    }
            json_data['label_encoders'] = le_json
            # 3. Write the JSON config file
            json_config_path = f'./captcha_models/captcha_model_{verify_type}.json'
            with open(json_config_path, 'w', encoding='utf-8') as f:
                json.dump(json_data, f, ensure_ascii=False, indent=4)
            print(f"💾 {verify_type} 模型JSON配置文件已保存到:{json_config_path}")
            # Evaluation metrics as CSV
            eval_df = pd.DataFrame({
                'metric': ['accuracy', 'auc', 'core_feature_contribution', 'best_iteration'],
                'value': [
                    eval_results[verify_type]['accuracy'],
                    eval_results[verify_type]['auc'],
                    eval_results[verify_type]['core_feature_contribution'],
                    eval_results[verify_type]['best_iteration']
                ]
            })
            eval_df.to_csv(f'{save_dir}/evaluation_results_{verify_type}.csv', index=False)
    # Global configuration JSON (shared across verify types)
    global_config = {
        'verify_type_configs': VERIFY_TYPE_CONFIGS,
        'useless_columns': USELESS_COLS,
        'training_timestamp': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    }
    global_config_path = './captcha_models/global_config.json'
    with open(global_config_path, 'w', encoding='utf-8') as f:
        json.dump(global_config, f, ensure_ascii=False, indent=4)
    print(f"💾 全局配置JSON文件已保存到:{global_config_path}")
    # Pickled label encoders (legacy consumers)
    le_path = './captcha_models/label_encoders.pkl'
    with open(le_path, 'wb') as f:
        pickle.dump(label_encoders, f)
    print(f"\n✅ 所有模型和配置已保存完成!")
# ====================== 统一预测接口(兼容JSON配置) ======================
def predict_captcha(input_data: dict):
    """
    Unified prediction entry point: detects verify_type and uses its model.

    Loads the per-type booster plus its JSON config (falling back to legacy
    pickle artifacts), replays the training-time preprocessing on the single
    input row, and returns the probability / label verdict.
    """
    # Resolve the verify type; anything unexpected maps to 'default'
    verify_type = input_data.get('verify_type', 'default').lower()
    if verify_type not in ['swipe', 'click']:
        verify_type = 'default'
    # Artifact paths for this type.
    # NOTE(review): only swipe/click models are trained by main(), so a
    # 'default' request fails the existence check below — confirm intended.
    model_path = f'./captcha_models/captcha_xgb_model_{verify_type}.bin'
    preprocess_path = f'./captcha_models/preprocessing_params_{verify_type}.pkl'
    json_config_path = f'./captcha_models/captcha_model_{verify_type}.json'
    if not os.path.exists(model_path):
        raise Exception(f"❌ 未找到 {verify_type} 类型的模型文件")
    # Load the booster
    model = xgb.Booster()
    model.load_model(model_path)
    # Prefer the JSON config when present
    if os.path.exists(json_config_path):
        with open(json_config_path, 'r', encoding='utf-8') as f:
            json_config = json.load(f)
        preprocess_params = json_config.get('preprocessing_params', {})
        feature_cols = json_config.get('feature_columns', [])
        label_encoders_json = json_config.get('label_encoders', {})
    else:
        # Legacy pickle fallback
        with open(preprocess_path, 'rb') as f:
            preprocess_params = pickle.load(f)
        feature_cols = preprocess_params.get('feature_cols', [])
        le_path = './captcha_models/label_encoders.pkl'
        with open(le_path, 'rb') as f:
            label_encoders = pickle.load(f)
        # NOTE(review): the pickled encoders are loaded but never applied — in
        # this branch label_encoders_json stays empty, so categorical columns
        # remain raw strings. Confirm whether the pickle path is still used.
        label_encoders_json = {}
    # Build a single-row frame and mirror the training preprocessing
    df = pd.DataFrame([input_data])
    # Drop columns excluded during training
    df = filter_useless_columns(df)
    # Clamp negative numeric values to 0 (matches training repair)
    for col in df.columns:
        if df[col].dtype in ['int64', 'float64'] and df[col].iloc[0] < 0:
            df[col] = 0
    # Encode categorical features via the JSON-exported encoder mappings
    for col in ['os', 'device_type', 'ip_type']:
        if col in df.columns and col in label_encoders_json:
            le_mapping = label_encoders_json[col]['mapping']
            df[col] = df[col].fillna('Unknown').astype(str)
            # Unseen categories fall back to the 'Unknown' code (or 0)
            df[col] = df[col].apply(lambda x: le_mapping.get(x, le_mapping.get('Unknown', 0)))
    # Align to the training feature columns; missing ones become 0
    for col in feature_cols:
        if col not in df.columns:
            df[col] = 0
    df = df[feature_cols]
    # Predict
    dmatrix = xgb.DMatrix(df)
    pred_prob = model.predict(dmatrix)[0]
    pred_label = 1 if pred_prob > 0.5 else 0
    return {
        'verify_type': verify_type,
        'abnormal_probability': float(pred_prob),
        'predicted_label': int(pred_label),
        'label_description': '异常/疑似' if pred_label == 1 else '正常',
        'prediction_timestamp': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    }
# ====================== 主函数 ======================
def main():
    """Entry point: train one model per verify_type (swipe / click).

    Loads the CSV feature dump, splits it by verify type, runs the full
    preprocessing + training + evaluation pipeline for each type that has
    data, saves all artifacts, and finally smoke-tests the prediction API.
    """
    # Configuration
    CSV_FILE_PATH = './captcha_verify_feature.csv'
    USE_OPTUNA = False
    try:
        # 1. Load data
        df = load_and_validate_data(CSV_FILE_PATH)
        # 2. Split by verify_type
        swipe_df = df[df['verify_type'] == 'swipe'].copy()
        click_df = df[df['verify_type'] == 'click'].copy()
        # Result containers
        models = {}
        eval_results = {}
        all_params = {}
        label_encoders = {}
        feature_cols_dict = {}
        # 3. Train the swipe model
        if len(swipe_df) > 0:
            print("\n" + "=" * 60)
            print("开始训练 SWIPE (滑动) 验证模型")
            print("=" * 60)
            # Preprocessing: labels, then extreme-sample filtering
            swipe_df = process_labels(swipe_df)
            swipe_df = filter_extreme_samples(swipe_df, 'swipe', ['load_to_first_delay', 'total_duration'])
            # Feature handling: classify, repair outliers, fill, add interactions
            feature_cols, categorical_cols, numeric_cols = select_features(swipe_df, 'swipe')
            swipe_df = fix_numeric_outliers(swipe_df, 'swipe', numeric_cols)
            swipe_df = handle_missing_values(swipe_df, 'swipe', numeric_cols, categorical_cols)
            swipe_df = create_core_feature_interactions(swipe_df, 'swipe')
            # Refresh the feature list (interaction columns were added above)
            feature_cols = [col for col in swipe_df.columns if col not in ['label', 'binary_label', 'verify_type']]
            swipe_df_encoded, le = encode_categorical_features(swipe_df, categorical_cols)
            label_encoders['swipe'] = le
            # Split features and label
            X = swipe_df_encoded[feature_cols]
            y = swipe_df_encoded['binary_label']
            # Balance classes
            X_balanced, y_balanced = balance_samples(X, y, 'swipe')
            # Train/test split (stratified only when both classes exist)
            X_train, X_test, y_train, y_test = train_test_split(
                X_balanced, y_balanced, test_size=0.2, random_state=42,
                stratify=y_balanced if y_balanced.nunique() >= 2 else None
            )
            # Train
            model, evals_result, params = train_xgboost_model(
                X_train, y_train, X_test, y_test, 'swipe', USE_OPTUNA
            )
            # Evaluate
            eval_res = evaluate_model(model, X_test, y_test, feature_cols, 'swipe')
            # Collect results
            models['swipe'] = model
            eval_results['swipe'] = eval_res
            all_params['swipe'] = params
            feature_cols_dict['swipe'] = feature_cols
        # 4. Train the click model (same pipeline as swipe)
        if len(click_df) > 0:
            print("\n" + "=" * 60)
            print("开始训练 CLICK (点击) 验证模型")
            print("=" * 60)
            # Preprocessing
            click_df = process_labels(click_df)
            click_df = filter_extreme_samples(click_df, 'click', ['load_to_first_delay', 'total_duration'])
            # Feature handling
            feature_cols, categorical_cols, numeric_cols = select_features(click_df, 'click')
            click_df = fix_numeric_outliers(click_df, 'click', numeric_cols)
            click_df = handle_missing_values(click_df, 'click', numeric_cols, categorical_cols)
            click_df = create_core_feature_interactions(click_df, 'click')
            # Refresh the feature list (interaction columns were added above)
            feature_cols = [col for col in click_df.columns if col not in ['label', 'binary_label', 'verify_type']]
            click_df_encoded, le = encode_categorical_features(click_df, categorical_cols)
            label_encoders['click'] = le
            # Split features and label
            X = click_df_encoded[feature_cols]
            y = click_df_encoded['binary_label']
            # Balance classes
            X_balanced, y_balanced = balance_samples(X, y, 'click')
            # Train/test split
            X_train, X_test, y_train, y_test = train_test_split(
                X_balanced, y_balanced, test_size=0.2, random_state=42,
                stratify=y_balanced if y_balanced.nunique() >= 2 else None
            )
            # Train
            model, evals_result, params = train_xgboost_model(
                X_train, y_train, X_test, y_test, 'click', USE_OPTUNA
            )
            # Evaluate
            eval_res = evaluate_model(model, X_test, y_test, feature_cols, 'click')
            # Collect results
            models['click'] = model
            eval_results['click'] = eval_res
            all_params['click'] = params
            feature_cols_dict['click'] = feature_cols
        # 5. Save everything (including the JSON exports)
        save_results(models, label_encoders, feature_cols_dict, eval_results, all_params)
        # 6. Smoke-test the prediction API
        print("\n" + "=" * 60)
        print("测试差异化预测功能")
        print("=" * 60)
        # Sample swipe request
        test_swipe_data = {
            'verify_type': 'swipe',
            'os': 'Android',
            'device_type': 'Mobile',
            'ip_type': '电信',
            'load_to_first_delay': 20000,
            'total_duration': 30000,
            'avg_swipe_interval': 100,
            'ip_verify_freq': 10
        }
        # Sample click request
        test_click_data = {
            'verify_type': 'click',
            'os': 'iOS',
            'device_type': 'Mobile',
            'ip_type': '联通',
            'load_to_first_delay': 15000,
            'total_duration': 25000,
            'click_interval_cv': 0.5,
            'ip_verify_freq': 8
        }
        # Predict (only when the corresponding model was trained)
        if 'swipe' in models:
            swipe_result = predict_captcha(test_swipe_data)
            print(f"\n📱 滑动验证预测结果:")
            print(f" 验证类型:{swipe_result['verify_type']}")
            print(f" 异常概率:{swipe_result['abnormal_probability']:.4f}")
            print(f" 预测结果:{swipe_result['label_description']}")
            print(f" 预测时间:{swipe_result['prediction_timestamp']}")
        if 'click' in models:
            click_result = predict_captcha(test_click_data)
            print(f"\n🖱️ 点击验证预测结果:")
            print(f" 验证类型:{click_result['verify_type']}")
            print(f" 异常概率:{click_result['abnormal_probability']:.4f}")
            print(f" 预测结果:{click_result['label_description']}")
            print(f" 预测时间:{click_result['prediction_timestamp']}")
        print("\n🎉 差异化模型训练和预测完成!")
        # Keep the plot window open after training
        plt.ioff()
        plt.show()
    except Exception as e:
        print(f"\n❌ 训练失败:{str(e)}")
        import traceback
        traceback.print_exc()
        plt.ioff()
        raise
if __name__ == "__main__":
    main()
推理
// CaptchaPredict is the package-level captcha prediction model instance
// (presumably assigned via NewCaptchaPredictor at startup — initialization
// is not visible in this file section; confirm).
var CaptchaPredict *CaptchaPredictor
// ========== 1. Structs matching the Python-generated JSON config (incl. median/mode fill fields) ==========

// ModelJSONConfig mirrors captcha_model_{type}.json written by the Python
// training pipeline: evaluation metrics, feature metadata, label-encoder
// mappings, and the fill values used for missing-value imputation.
type ModelJSONConfig struct {
	VerifyType          string                 `json:"verify_type"`
	ModelConfig         map[string]interface{} `json:"model_config"`
	PreprocessingParams map[string]interface{} `json:"preprocessing_params"`
	EvaluationMetrics   struct {
		Accuracy                float64 `json:"accuracy"`
		Auc                     float64 `json:"auc"`
		CoreFeatureContribution float64 `json:"core_feature_contribution"`
		BestIteration           int     `json:"best_iteration"`
	} `json:"evaluation_metrics"`
	FeatureImportance map[string]float64 `json:"feature_importance"`
	CoreFeatures      []string           `json:"core_features"`
	FeatureColumns    []string           `json:"feature_columns"`
	TrainingTimestamp string             `json:"training_timestamp"`
	LabelEncoders     map[string]struct {
		Classes    []string       `json:"classes"`
		Mapping    map[string]int `json:"mapping"`
		NumClasses int            `json:"num_classes"`
	} `json:"label_encoders"`
	NumericMedians   map[string]float64 `json:"numeric_medians"`   // numeric feature medians (missing-value fill)
	CategoricalModes map[string]string  `json:"categorical_modes"` // categorical feature modes (missing-value fill)
}
// ========== 2. XGBoost model structs (native JSON format of XGBoost 2.x) ==========

// XGBoostModel maps the native JSON dump of an XGBoost 2.x booster.
// Note that XGBoost serializes many numeric parameters as strings
// ("num_trees", "base_score", ...), hence the string-typed fields.
type XGBoostModel struct {
	Version []int `json:"version"`
	Learner struct {
		Attributes struct {
			BestIteration string `json:"best_iteration"`
			BestScore     string `json:"best_score"`
		} `json:"attributes"`
		FeatureNames    []string `json:"feature_names"`
		FeatureTypes    []string `json:"feature_types"`
		GradientBooster struct {
			Model struct {
				GBTreeModelParam struct {
					NumParallelTree string `json:"num_parallel_tree"`
					NumTrees        string `json:"num_trees"`
				} `json:"gbtree_model_param"`
				IterationIndptr []int  `json:"iteration_indptr"`
				TreeInfo        []int  `json:"tree_info"`
				Trees           []Tree `json:"trees"` // decision tree array
			} `json:"model"`
			TreeParam struct {
				NumFeature string `json:"num_feature"`
			} `json:"tree_param"`
		} `json:"gradient_booster"`
		LearnerModelParam struct {
			BaseScore  string `json:"base_score"`
			NumClass   string `json:"num_class"`
			NumFeature string `json:"num_feature"`
			NumTarget  string `json:"num_target"`
		} `json:"learner_model_param"`
		Objective struct {
			Name string `json:"name"`
		} `json:"objective"`
	} `json:"learner"`
}
// Tree mirrors one decision tree in the XGBoost 2.x native JSON format.
// All node arrays are indexed by node id; LeftChildren[i] == -1 marks node i
// as a leaf (this is how the predict traversal detects leaves).
type Tree struct {
	BaseWeights        []float64 `json:"base_weights"`
	Categories         []string  `json:"categories"`
	CategoriesNodes    []int     `json:"categories_nodes"`
	CategoriesSegments []int     `json:"categories_segments"`
	CategoriesSizes    []int     `json:"categories_sizes"`
	DefaultLeft        []int     `json:"default_left"`
	ID                 int       `json:"id"`
	LeftChildren       []int     `json:"left_children"`
	LossChanges        []float64 `json:"loss_changes"`
	Parents            []int     `json:"parents"`
	RightChildren      []int     `json:"right_children"`
	SplitConditions    []float64 `json:"split_conditions"`
	SplitIndices       []int     `json:"split_indices"`
	SplitType          []int     `json:"split_type"`
	SumHessian         []float64 `json:"sum_hessian"`
	TreeParam          struct {
		NumDeleted     string `json:"num_deleted"`
		NumFeature     string `json:"num_feature"`
		NumNodes       string `json:"num_nodes"`
		SizeLeafVector string `json:"size_leaf_vector"`
	} `json:"tree_param"`
}
// ========== 3. Per-verify-type predictor ==========

// CaptchaPredictor holds the loaded XGBoost models and their JSON configs,
// keyed by verify type ("swipe" / "click").
type CaptchaPredictor struct {
	// root directory containing the model/config JSON files
	modelRootPath string
	// loaded boosters and their matching configs, keyed by verify type
	models  map[string]*XGBoostModel
	configs map[string]*ModelJSONConfig
	// true once NewCaptchaPredictor has loaded at least one model
	initialized bool
}
// ========== 4. Prediction result ==========

// PredictResult is the JSON-serializable outcome of a single prediction.
type PredictResult struct {
	VerifyType          string  `json:"verify_type"`
	AbnormalProbability float64 `json:"abnormal_probability"`
	PredictedLabel      int     `json:"predicted_label"`
	LabelDescription    string  `json:"label_description"`
	PredictionTimestamp string  `json:"prediction_timestamp"`
}
// ========== 5. Constructor ==========

// NewCaptchaPredictor builds a predictor rooted at modelRootPath, loading the
// JSON config and the native XGBoost JSON model for each supported verify
// type. Types that fail to load are skipped; an error is returned only when
// no model could be loaded at all.
func NewCaptchaPredictor(modelRootPath string) (*CaptchaPredictor, error) {
	p := &CaptchaPredictor{
		modelRootPath: modelRootPath,
		models:        make(map[string]*XGBoostModel),
		configs:       make(map[string]*ModelJSONConfig),
		initialized:   false,
	}
	for _, vt := range []string{"swipe", "click"} {
		// Python-generated JSON config: captcha_model_{vt}.json
		cfg, err := loadModelConfig(fmt.Sprintf("%s/captcha_model_%s.json", modelRootPath, vt))
		if err != nil {
			log.Printf("⚠️ 加载%s配置失败: %v,跳过该模型", vt, err)
			continue
		}
		p.configs[vt] = cfg
		// Python-converted booster dump: captcha_xgb_model_{vt}.json
		mdl, err := loadXGBoostModel(fmt.Sprintf("%s/captcha_xgb_model_%s.json", modelRootPath, vt))
		if err != nil {
			log.Printf("⚠️ 加载%s模型失败: %v,跳过该模型", vt, err)
			delete(p.configs, vt)
			continue
		}
		p.models[vt] = mdl
		log.Printf("✅ 成功加载%s模型:特征数=%d,决策树数=%d",
			vt, len(cfg.FeatureColumns), len(mdl.Learner.GradientBooster.Model.Trees))
	}
	if len(p.models) == 0 {
		return nil, fmt.Errorf("未加载到任何验证码模型")
	}
	p.initialized = true
	log.Println("✅ 验证码预测器初始化完成")
	return p, nil
}
// ========== 6. 内部加载方法 ==========
// loadModelConfig reads and parses a Python-generated model JSON config file.
// The parse error is wrapped with %w so callers can unwrap it with
// errors.Is/errors.As (the original used %v, which broke the error chain).
func loadModelConfig(configPath string) (*ModelJSONConfig, error) {
	data, err := os.ReadFile(configPath)
	if err != nil {
		return nil, err
	}
	var config ModelJSONConfig
	if err := json.Unmarshal(data, &config); err != nil {
		return nil, fmt.Errorf("解析配置失败: %w", err)
	}
	return &config, nil
}
// loadXGBoostModel reads and parses a native XGBoost 2.x JSON model dump.
// The parse error is wrapped with %w so callers can unwrap it with
// errors.Is/errors.As (the original used %v, which broke the error chain).
func loadXGBoostModel(modelPath string) (*XGBoostModel, error) {
	data, err := os.ReadFile(modelPath)
	if err != nil {
		return nil, err
	}
	var model XGBoostModel
	if err := json.Unmarshal(data, &model); err != nil {
		return nil, fmt.Errorf("解析模型失败: %w", err)
	}
	return &model, nil
}
// ========== 7. Preprocessing (mirrors the Python training pipeline) ==========

// preprocess converts a raw feature map into the model's ordered feature
// vector for the given verify type, replaying the Python-side cleaning:
// useless-column zeroing, missing-value fills (mode for categorical, median
// for numeric), label encoding, negative-value repair, quantile clamping,
// and interaction-feature computation.
//
// Fix: the negative-value repair previously zeroed numVal BEFORE logging it,
// so the "原值" placeholder always printed 0.00; it now logs the original
// value first.
func (p *CaptchaPredictor) preprocess(rawData map[string]interface{}, verifyType string) ([]float64, error) {
	if !p.initialized {
		return nil, fmt.Errorf("预测器未初始化")
	}
	// Config for this verify type
	config, ok := p.configs[verifyType]
	if !ok {
		return nil, fmt.Errorf("未找到%s类型的配置", verifyType)
	}
	// Feature vector in training column order
	featureCols := config.FeatureColumns
	features := make([]float64, len(featureCols))
	// Column name -> vector index
	featureIndex := make(map[string]int, len(featureCols))
	for idx, col := range featureCols {
		featureIndex[col] = idx
	}
	// Columns dropped during training (mirror of Python's USELESS_COLS)
	uselessCols := map[string]bool{
		"id":                 true,
		"sample_id":          true,
		"session_id":         true,
		"ip":                 true,
		"device_fingerprint": true,
		"temp_score":         true,
		"created_at":         true,
		"updated_at":         true,
		"user_id":            true,
		"device_id":          true,
		"phone_number":       true,
		"email":              true,
		"cookie_id":          true,
		"session_token":      true,
	}
	for col, idx := range featureIndex {
		// Useless columns contribute a constant 0
		if uselessCols[col] {
			features[idx] = 0.0
			continue
		}
		val, exists := rawData[col]
		// Missing-value fill (matches the Python fill logic)
		if !exists || val == nil {
			// Categorical feature: fill with the training-time mode
			if le, ok := config.LabelEncoders[col]; ok {
				modeVal := config.CategoricalModes[col]
				if modeVal == "" {
					modeVal = "Unknown"
				}
				code, ok := le.Mapping[modeVal]
				if !ok {
					code = le.Mapping["Unknown"]
				}
				features[idx] = float64(code)
				log.Printf("⚠️ 特征[%s]缺失,填充众数[%s]编码: %d", col, modeVal, code)
				continue
			}
			// Numeric feature: fill with the training-time median (0 if absent)
			medianVal, ok := config.NumericMedians[col]
			if !ok {
				// Core-feature fallbacks (currently all 0; kept explicit for clarity)
				if strings.Contains(col, "delay") || strings.Contains(col, "duration") {
					medianVal = 0.0
				} else if strings.Contains(col, "freq") {
					medianVal = 0.0
				} else {
					medianVal = 0.0
				}
			}
			features[idx] = medianVal
			//log.Printf("⚠️ 特征[%s]缺失,填充中位数: %.2f", col, medianVal)
			continue
		}
		// Categorical feature: label-encode via the training-time mapping
		if le, ok := config.LabelEncoders[col]; ok {
			var strVal string
			switch v := val.(type) {
			case string:
				strVal = v
			case nil:
				strVal = "Unknown"
			default:
				strVal = fmt.Sprintf("%v", v)
			}
			code, ok := le.Mapping[strVal]
			if !ok {
				// Unseen category falls back to the Unknown code
				code = le.Mapping["Unknown"]
				log.Printf("⚠️ 特征[%s]未知值: %s,使用Unknown编码: %d", col, strVal, code)
			}
			features[idx] = float64(code)
			continue
		}
		// Numeric feature: coerce to float64, falling back to the median
		var numVal float64
		switch v := val.(type) {
		case float64:
			numVal = v
		case int:
			numVal = float64(v)
		case int64:
			numVal = float64(v)
		case string:
			n, err := strconv.ParseFloat(v, 64)
			if err != nil {
				// Unparseable string: median fill
				medianVal, ok := config.NumericMedians[col]
				if !ok {
					medianVal = 0.0
				}
				numVal = medianVal
				log.Printf("⚠️ 特征[%s]值错误: %s,填充中位数: %.2f", col, v, medianVal)
			} else {
				numVal = n
			}
		default:
			// Unsupported type: median fill
			medianVal, ok := config.NumericMedians[col]
			if !ok {
				medianVal = 0.0
			}
			numVal = medianVal
			log.Printf("⚠️ 特征[%s]类型错误: %T,填充中位数: %.2f", col, v, medianVal)
		}
		// Negative repair (Python replaces all negatives with 0).
		// BUGFIX: log the original value before overwriting it.
		if numVal < 0 {
			log.Printf("⚠️ 特征[%s]负数修复为0(原值: %.2f)", col, numVal)
			numVal = 0.0
		}
		// Quantile clamping using the Python-saved q_low/q_high
		if colParams, ok := config.PreprocessingParams[col].(map[string]interface{}); ok {
			// Lower bound
			if qLow, ok := colParams["q_low"].(float64); ok && numVal < qLow {
				log.Printf("⚠️ 特征[%s]低于下界%.2f,截断为%.2f(原值: %.2f)", col, qLow, qLow, numVal)
				numVal = qLow
			}
			// Upper bound
			if qHigh, ok := colParams["q_high"].(float64); ok && numVal > qHigh {
				log.Printf("⚠️ 特征[%s]高于上界%.2f,截断为%.2f(原值: %.2f)", col, qHigh, qHigh, numVal)
				numVal = qHigh
			}
		}
		features[idx] = numVal
	}
	// Interaction features (mirrors Python's create_core_feature_interactions)
	p.calculateFeatureInteractions(features, featureIndex, rawData, verifyType)
	return features, nil
}
// calculateFeatureInteractions fills in the interaction features in place,
// matching the Python create_core_feature_interactions logic: one shared
// ratio plus a per-verify-type interaction. Each interaction is only
// computed when all of its source and destination columns exist.
func (p *CaptchaPredictor) calculateFeatureInteractions(features []float64, featureIndex map[string]int, rawData map[string]interface{}, verifyType string) {
	// lookup resolves feature names to indices; ok is false unless all exist.
	lookup := func(names ...string) ([]int, bool) {
		idxs := make([]int, 0, len(names))
		for _, name := range names {
			i, found := featureIndex[name]
			if !found {
				return nil, false
			}
			idxs = append(idxs, i)
		}
		return idxs, true
	}
	// Shared interaction: load_to_total_ratio = load_to_first_delay / (total_duration + 1e-6)
	if idx, ok := lookup("load_to_first_delay", "total_duration", "load_to_total_ratio"); ok {
		ratio := features[idx[0]] / (features[idx[1]] + 1e-6)
		features[idx[2]] = ratio
		log.Printf("✅ 计算通用交互项 load_to_total_ratio = %.4f", ratio)
	}
	switch verifyType {
	case "swipe":
		// Swipe: load_swipe_ratio = load_to_first_delay / (avg_swipe_interval + 1e-6)
		if idx, ok := lookup("load_to_first_delay", "avg_swipe_interval", "load_swipe_ratio"); ok {
			ratio := features[idx[0]] / (features[idx[1]] + 1e-6)
			features[idx[2]] = ratio
			log.Printf("✅ 计算滑动交互项 load_swipe_ratio = %.4f", ratio)
		}
	case "click":
		// Click: duration_click_cv = total_duration * click_interval_cv
		if idx, ok := lookup("total_duration", "click_interval_cv", "duration_click_cv"); ok {
			product := features[idx[0]] * features[idx[1]]
			features[idx[2]] = product
			log.Printf("✅ 计算点击交互项 duration_click_cv = %.4f", product)
		}
	}
}
// ========== 8. 预测方法(优化XGBoost推理逻辑) ==========
// predict runs XGBoost binary-classification inference for verifyType over
// the already-preprocessed feature vector and returns (probability, label,
// error). The label is 1 when probability > 0.5, matching the Python side.
func (p *CaptchaPredictor) predict(features []float64, verifyType string) (float64, int, error) {
	if !p.initialized {
		return 0, 0, fmt.Errorf("预测器未初始化")
	}
	// Look up the model and preprocessing config for this verify type.
	model, ok := p.models[verifyType]
	if !ok {
		return 0, 0, fmt.Errorf("未找到%s类型的模型", verifyType)
	}
	config, ok := p.configs[verifyType]
	if !ok {
		return 0, 0, fmt.Errorf("未找到%s类型的配置", verifyType)
	}
	// Parse BaseScore, falling back to 0.5 when absent or invalid.
	// It must lie strictly inside (0,1): log-odds of 0 or 1 is ±Inf and
	// would turn the final probability into 0/1/NaN regardless of trees.
	actualBaseScore := 0.5
	if raw := model.Learner.LearnerModelParam.BaseScore; raw != "" {
		if bs, err := strconv.ParseFloat(raw, 64); err == nil && bs > 0 && bs < 1 {
			actualBaseScore = bs
		}
	}
	// The feature vector must line up with the training feature columns.
	featureCols := config.FeatureColumns
	if len(features) != len(featureCols) {
		return 0, 0, fmt.Errorf("特征长度不匹配:期望%d,实际%d", len(featureCols), len(features))
	}
	trees := model.Learner.GradientBooster.Model.Trees
	if len(trees) == 0 {
		log.Println("⚠️ 模型中未找到决策树,返回默认概率0.5")
		return 0.5, 0, nil
	}
	totalScore := 0.0
	validTreeCount := 0
	for treeIdx, tree := range trees {
		currentNodeID := 0 // start at the root
		nodeCount, err := strconv.Atoi(tree.TreeParam.NumNodes)
		if err != nil || nodeCount == 0 {
			log.Printf("⚠️ 树%d节点数无效: %s,跳过", treeIdx, tree.TreeParam.NumNodes)
			continue
		}
		// Walk down to a leaf. The walk is capped at nodeCount steps so a
		// malformed (cyclic) tree cannot loop forever.
		for steps := 0; steps <= nodeCount; steps++ {
			// Bounds check against every parallel array we index below;
			// a truncated model JSON must not panic the service.
			if currentNodeID < 0 ||
				currentNodeID >= len(tree.LeftChildren) ||
				currentNodeID >= len(tree.RightChildren) {
				break
			}
			// Leaf node: left child is -1 — accumulate its weight.
			leftChild := tree.LeftChildren[currentNodeID]
			if leftChild == -1 {
				if currentNodeID < len(tree.BaseWeights) {
					totalScore += tree.BaseWeights[currentNodeID]
					validTreeCount++
				}
				break
			}
			// Internal node: need the split feature and threshold.
			if currentNodeID >= len(tree.SplitIndices) || currentNodeID >= len(tree.SplitConditions) {
				break
			}
			splitFeatIdx := tree.SplitIndices[currentNodeID]
			splitThreshold := tree.SplitConditions[currentNodeID]
			if splitFeatIdx < 0 || splitFeatIdx >= len(features) {
				// Unknown feature index: take the right branch.
				currentNodeID = tree.RightChildren[currentNodeID]
				continue
			}
			// Missing/invalid values go left; otherwise compare against the
			// split threshold (native XGBoost <= semantics).
			featVal := features[splitFeatIdx]
			if math.IsNaN(featVal) || math.IsInf(featVal, 0) || featVal <= splitThreshold {
				currentNodeID = leftChild
			} else {
				currentNodeID = tree.RightChildren[currentNodeID]
			}
		}
	}
	// Binary-classification link function:
	// prob = 1 / (1 + exp(-(base_score_log_odds + sum(tree_scores)))).
	baseLogOdds := math.Log(actualBaseScore / (1 - actualBaseScore))
	finalLogOdds := baseLogOdds + totalScore
	predProb := 1.0 / (1.0 + math.Exp(-finalLogOdds))
	// Label with the 0.5 threshold, matching Python.
	predLabel := 0
	if predProb > 0.5 {
		predLabel = 1
	}
	log.Printf("✅ %s模型推理完成:有效树数=%d,BaseScore=%.4f,总得分=%.4f,概率=%.4f,标签=%d",
		verifyType, validTreeCount, actualBaseScore, totalScore, predProb, predLabel)
	return predProb, predLabel, nil
}
// ========== 9. 统一预测接口(完全匹配Python的predict_captcha) ==========
// PredictCaptcha is the unified prediction entry point (mirrors Python's
// predict_captcha): it resolves the verify type from rawData, preprocesses
// the raw data into a feature vector, runs inference, and packages the
// result into a PredictResult.
func (p *CaptchaPredictor) PredictCaptcha(rawData map[string]interface{}) (*PredictResult, error) {
	if !p.initialized {
		return nil, fmt.Errorf("预测器未初始化")
	}
	// Resolve the verify type; a missing or non-string value means "default".
	verifyType := "default"
	if vt, ok := rawData["verify_type"].(string); ok {
		verifyType = vt
	}
	verifyType = strings.ToLower(verifyType)
	// Unsupported types fall back to whichever model is loaded, preferring
	// swipe over click (same ordering as the Python code).
	if verifyType != "swipe" && verifyType != "click" {
		if _, hasSwipe := p.models["swipe"]; hasSwipe {
			verifyType = "swipe"
		} else if _, hasClick := p.models["click"]; hasClick {
			verifyType = "click"
		} else {
			return nil, fmt.Errorf("无可用的验证码模型")
		}
		log.Printf("⚠️ 验证类型%s不支持,自动切换为%s", rawData["verify_type"], verifyType)
	}
	// Preprocess (fully matching the Python pipeline), then run inference.
	features, err := p.preprocess(rawData, verifyType)
	if err != nil {
		return nil, fmt.Errorf("预处理失败: %v", err)
	}
	prob, label, err := p.predict(features, verifyType)
	if err != nil {
		return nil, fmt.Errorf("预测失败: %v", err)
	}
	// Assemble the result in the exact shape the Python version returns.
	desc := "正常"
	if label == 1 {
		desc = "异常/疑似"
	}
	return &PredictResult{
		VerifyType:          verifyType,
		AbnormalProbability: prob,
		PredictedLabel:      label,
		LabelDescription:    desc,
		PredictionTimestamp: time.Now().Format("2006-01-02 15:04:05"),
	}, nil
}

