(19)国家知识产权局
(12)发明 专利申请
(10)申请公布号
(43)申请公布日
(21)申请 号 202210908877.5
(22)申请日 2022.07.29
(71)申请人 平安科技 (深圳) 有限公司
地址 518000 广东省深圳市福田区福田街
道福安社区益田路5033号平 安金融中
心23楼
(72)发明人 李泽远 王健宗 曹康养
(74)专利代理 机构 广州嘉权专利商标事务所有
限公司 4 4205
专利代理师 廖慧贤
(51)Int.Cl.
G06V 10/44(2022.01)
G06V 10/762(2022.01)
G06V 10/764(2022.01)
G06V 10/82(2022.01)G06N 3/04(2006.01)
G06N 3/08(2006.01)
G06N 20/00(2019.01)
(54)发明名称
模型训练方法和装置、 电子设备、 存 储介质
(57)摘要
本申请实施例提供了模 型训练方法和装置、
电子设备、 存储介质, 属于 人工智能技术领域。 该
模型训练方法包括: 获取待训练的原始图像数
据, 并获取服务器端发送的原始模 型的原始训练
参数; 通过原始模型对原始图像数据进行标签预
测得到预测标签, 通过原始模型对原始图像数据
进行特征提取得到初步图像特征; 对初步图像特
征进行投影聚类处理, 得到域标签和域组合标
签; 根据域标签和域组合标签计算距离损失函
数; 根据预设的交叉熵损失函数和距离损失函数
更新第一、 二原始网络参数得到第一、 二目标参
数; 将第一、 二目标参数 发送给服务器端, 使 服务
器端更新原始模 型得到全局模型。 本申请实施例
基于联邦学习进行模型训练可以提高模型的训
练效率和准确率。
权利要求书3页 说明书15页 附图7页
CN 115205546 A
2022.10.18
CN 115205546 A
1.一种模型训练方法, 应用于客户端, 其特 征在于, 所述模型训练方法包括:
获取待训练的原始图像数据, 并获取服务器端发送的原始模型的原始训练参数; 其中,
所述原始图像数据是无标注数据, 所述原始训练参数包括所述原始模型的第一原始网络参
数、 第二原 始网络参数、 聚类数量;
将所述原始图像数据输入所述原始模型, 通过所述原始模型对所述原始图像数据进行
标签预测得到预测标签, 通过所述原始模型对 所述原始图像数据进行特征提取得到初步图
像特征;
对所述初步图像特征进行投影聚类处理, 得到域标签和域组合标签; 其中, 域标签和域
组合标签的数量均为第一数量, 且所述第一数量 等于所述聚类数量;
根据所述 域标签和所述 域组合标签 计算距离损失函数;
根据预设的交叉熵损失函数和所述距离损失函数更新所述第一原始网络参数得到第
一目标参数, 根据所述交叉熵损失函数和所述距离损失函数更新所述第二原始网络参数得
到第二目标参数; 其中, 所述交叉熵损失函数由所述预测标签进行 预先构建得到;
将所述第一目标参数和所述第二目标参数发送给所述服务器端; 其中, 所述第一目标
参数和所述第二目标参数用于所述 服务器端更新所述原 始模型得到全局模型。
2.根据权利要求1所述的模型训练方法, 其特征在于, 所述原始模型包括图卷积网络和
卷积神经网络, 所述通过所述原始模型对所述原始图像数据进行标签预测得到预测标签,
包括:
通过所述图卷积网络对所述原 始图像数据进行参数提取, 得到初步 参数;
通过所述卷积神经网络对所述初步 参数进行 特征图提取, 得到所述初步特 征图;
通过预设的激活函数对所述初步特 征图进行 标签预测, 得到所述预测标签。
3.根据权利要求1所述的模型训练方法, 其特征在于, 所述对所述初步图像特征进行投
影聚类处 理, 得到域标签和域组合标签, 包括:
将所述初步图像特 征映射到同一维度, 得到投影特 征;
通过预设的聚类算法对所述投影特征进行聚类处理, 得到所述域标签和所述域组合标
签。
4.根据权利要求3所述的模型训练方法, 其特征在于, 所述通过预设的聚类算法对所述
投影特征进行聚类处 理, 得到所述 域标签和所述 域组合标签, 包括:
通过k‑means++算法对所述投影特 征进行聚类处 理, 得到所述第一数量的聚类中心;
从所述第一数量的聚类中心中选择一个作为 参考聚类中心;
计算每一所述投影特 征与所述 参考聚类中心的距离, 得到聚类距离;
根据所述聚类距离和预设系数计算得到所述第二数量的域组合标签。
5.根据权利要求4所述的模型训练方法, 其特征在于, 所述域组合标签包括所述预设系
数, 所述根据所述 域标签和所述 域组合标签 计算距离损失函数:
根据所述 域组合标签获取 所述预设系数的最大值, 得到目标系数;
根据所述目标系数从所述第 一数量的聚类中心筛选出目标中心, 从所述第 一数量的聚
类中心过滤所述目标中心得到第二数量的当前聚类中心; 其中, 所述第二数量等于所述第
一数量减1;
计算所述投影特征与 所述目标中心之间的距离得到第 一距离, 并计算所述投影特征与权 利 要 求 书 1/3 页
2
CN 115205546 A
2每一所述当前聚类中心之间的距离得到所述第二数量的第二距离;
将所述第一距离与每一所述第二距离进行求差计算, 得到所述第二数量的距离 差;
根据所述第二数量的距离 差计算所述距离损失函数。
6.根据权利要求1至5任一项所述的模型训练方法, 其特征在于, 所述方法还包括: 构建
所述交叉熵损失函数, 具体包括:
获取所述原始图像数据的原 始标签; 所述原 始标签是 所述原始图像数据的真实标签;
将所述预测标签与所述真实标签进行比对, 得到标签比对结果;
根据所述预测标签和所述对比结果构建所述交叉熵损失函数。
7.一种模型训练方法, 应用于服 务器端, 其特 征在于, 所述模型训练方法包括:
向客户端发送原始模型的训练参数; 其中, 所述原始训练参数包括所述原始模型的第
一原始网络参数、 第二原 始网络参数;
获取所述客户端对所述第一原始网络参数更新得到的第一目标参数和对所述第二原
始网络参数更新得到的第二 目标参数; 其中, 所述第一 目标参数和所述第二 目标参数是根
据如权利要求1至 6任一项所述的模型训练方法训练得到;
将所述第一目标参数和所述第二目标参数进行整合处 理, 得到全局模型参数;
根据所述全局模型参数 更新所述原 始模型, 得到全局模型。
8.一种模型训练装置, 应用于客户端, 其特 征在于, 所述模型训练装置包括:
原始图像获取模块, 用于获取待训练的原始图像数据, 并获取服务器端发送的原始模
型的原始训练参数; 其中, 所述原始图像数据是无标注数据, 所述原始训练参数包括所述原
始模型的第一原 始网络参数、 第二原 始网络参数、 聚类数量;
模型处理模块, 用于将所述原始图像数据输入所述原始模型, 通过所述原始模型对所
述原始图像数据进行标签预测得到预测标签, 通过所述原始模型对所述原始图像数据进 行
特征提取得到初步图像特 征;
聚类模块, 用于对所述初步图像特征进行投影聚类处理, 得到域标签和域组合标签; 其
中, 域标签和域组合标签的数量均为第一数量, 且所述第一数量 等于所述聚类数量;
距离损失函数构建模块, 用于根据所述 域标签和所述 域组合标签 计算距离损失函数;
参数更新模块, 用于根据预设的交叉熵损失函数和所述距离损失函数更新所述第 一原
始网络参数得到第一目标参数, 根据所述交叉熵损失函数和所述距离损失函数更新所述第
二原始网络参数得到第二 目标参数; 其中, 所述交叉熵损失函数 由所述预测标签进行预先
构建得到;
参数发送模块, 用于将所述第一目标参数和所述第二目标参数发送给所述服务器端;
其中, 所述第一目标参数和所述第二目标参数用于所述服务器端 更新所述原始模型得到全
局模型。
9.一种电子设备, 其特征在于, 所述电子设备包括存储器、 处理器、 存储在所述存储器
上并可在所述处理器上运行的程序以及用于实现所述处理器和所述存储器之间的连接通
信的数据总线, 所述 程序被所述处 理器执行时实现:
如权利要求1至 6任一项所述的模型训练方法的步骤;
或者,
如权利要求7 所述的模型训练方法的步骤。权 利 要 求 书 2/3 页
3
CN 115205546 A
3
专利 模型训练方法和装置、电子设备、存储介质
文档预览
中文文档
26 页
50 下载
1000 浏览
0 评论
309 收藏
3.0分
温馨提示:本文档共26页,可预览 3 页,如浏览全部内容或当前文档出现乱码,可开通会员下载原始文档
本文档由 人生无常 于 2024-03-18 04:42:23上传分享