现代推荐系统每天产生数十亿条行为日志,每次模型架构迭代都要在全量历史数据上重训——代价极高。DIET 提出「流式数据集蒸馏」,维护一个仅占原数据 1-2% 的合成记忆集,且随流式数据持续更新,让新模型在这个小数据集上训练就能复现全量训练的行为,最多降低 60× 迭代成本。
论文链接:https://arxiv.org/abs/2603.24958
关键词:Dataset Distillation · Continual Learning · Recommender Systems · Influence Function · Bi-level Optimization
1.1 问题:模型迭代越来越贵
工业推荐系统在持续学习(Continual Learning)范式下运行:模型参数跨时间段继承更新,每次新模型架构上线前需要在全量历史数据上重新训练,以获得可靠的性能验证。随着模型越来越大,这个代价呈指数级增长:
| 模型 | 参数量 | 推理 FLOPs/样本 | 6个月数据训练成本(GPU 时) |
|---|---|---|---|
| DLRM | 8.7M | 52 G | ~9.7 万 |
| Wukong | 122M | 442 G | ~590 万 |
| RankMixer | 1B | 2106 G | ~6000 万 |
注:训练成本估算基于 RTX 4090 单卡,$C = D \times F$,$T_{\text{hour}} = C / (\eta \cdot P \cdot 3600)$。
1.2 现有方案的失败
方案 A:窗口采样(Window Sampling)
只用最近一段时间的数据。
问题:丢失历史行为模式,模型优化轨迹和全量训练完全不同。
方案 B:随机采样(Random Selection)
随机从历史数据中抽取一个子集。
问题:Figure 1 实验显示与全量训练结果的相关性仅 r=0.319,几乎不可信。
1.3 Dataset Distillation vs Coreset Selection
| 方法 | 做法 | 问题 |
|---|---|---|
| Coreset Selection | 从原始数据中挑选「代表性」的真实样本子集 | 保留的是真实样本,但实验表明在深度学习场景下效果接近随机采样,无法保留训练动态 |
| Dataset Distillation (现有方法) |
通过优化合成一批假样本,让在假样本上训练的效果 ≈ 全量数据 | 现有方法都是静态的(假设数据集固定),无法处理流式持续更新的推荐场景 |
| DIET(本文) | 流式数据集蒸馏:蒸馏出的合成数据集随流式数据持续更新,维持紧凑 | — |
1.4 形式化定义
推荐场景下每条交互记录为 $(x, y)$,$x$ 是多域稀疏特征,$y \in \{0,1\}^K$ 是 $K$ 个任务的标签。全量训练目标:
- $e_i = \text{Embed}(x_i)$:embedding 层将稀疏特征映射为稠密向量
- $f_\theta(\cdot)$:预测模型,输出 $K$ 维预测向量
- $\ell(\cdot)$:BCE(二元交叉熵)loss
- $\sigma(\cdot)$:sigmoid 函数,逐元素应用
在持续学习范式下,将全量历史切分为有序数据块 $\mathcal{D}_{1:T} = \{\mathcal{D}_1, \mathcal{D}_2, \ldots, \mathcal{D}_T\}$,区分两种模型角色:
Reference Model $\phi$
持续演化的参考模型,每个阶段用新数据块更新参数:$\phi_t = \text{Update}(\phi_{t-1}, \mathcal{D}_t)$。它「见过」全量历史数据。
Candidate Model $\theta$
待验证的新架构或新训练策略,只在某个特定阶段介入,没有历史数据。挑战是:让它不训练历史数据也能「追上」参考模型的效果。
流式数据集蒸馏目标:构造一系列紧凑合成数据集 $\mathcal{D}^{syn}_{1:T}$,使在其上训练的 candidate model 能近似在全量历史数据上训练的行为,同时合成数据集随流式数据持续更新。
在 bi-level 优化框架下,每个阶段 $t$ 的形式化目标为:
- 内层优化(inner loop):$\phi_t(\mathcal{D}^{syn}_t) = \text{Opt}^K(\phi_{t-1}, \mathcal{D}^{syn}_t)$,在合成数据上跑 $K$ 步梯度更新得到代理模型
- 外层优化(outer loop):用真实数据 $\mathcal{D}_t$ 上的 loss 来评估并更新合成数据集
- $\mathcal{L}_{real}$:代理模型在真实数据上的 loss,越小说明合成数据「模拟」全量数据越准
DIET 框架在每个时间阶段分两个 Phase 运作:Phase 1 初始化并维护合成数据的「边界记忆」,Phase 2 通过影响力引导的记忆寻址对合成数据做精细优化。
2.1 Phase 1:边界记忆的初始化与维护
2.1.1 决策边界样本初始化(EL2N 选样)
推荐系统的多任务数据中,不同标签组合(如「点击+购买」「仅点击」「均未点击」)对模型决策边界的影响力差异巨大,且高频行为会淹没低频但关键的样本。DIET 采用 label-conditioned EL2N 评分:对每个 label 配置 $\mathbf{c} \in \{0,1\}^K$,在其子集内单独排序,选出影响力居中偏高的样本("upper-middle range",避免极端噪声):
- $\phi_t$:当前阶段的参考模型(已见过 $\mathcal{D}_{1:t}$)
- $\text{EL2N}$ 越大:模型对该样本的预测误差越大,说明该样本处于当前决策边界附近,训练影响力强
- 选 "upper-middle range":不选最高分(可能是噪声样本),不选最低分(已学好,冗余)
从所有 label 组合选出的样本汇聚为边界候选集 $\mathcal{S}_t = \bigcup_{\mathbf{c}} \mathcal{S}^{\mathbf{c}}_t$,确保跨行为类型的均衡覆盖。
关键设计:不直接存储原始样本 $(x, y)$,而是转为 embedding + soft label 形式:
- $\tilde{y}$ 是参考模型的输出 logit,包含比 one-hot 标签更丰富的「排名信息」(类似知识蒸馏的 soft target)
- embedding $e$ 作为连续可微参数,后续可以被梯度更新优化,不再受原始离散特征约束
- 存储空间与原始稀疏特征维度无关,只和 embedding 维度 $d$ 相关
假设有 3 个预测任务:点击(CTR)、点赞(Like)、关注(Follow),则 $K=3$,标签空间共 8 种组合,如 $(1,0,0)$(只点击)、$(1,1,0)$(点击+点赞)等。
直接从全量数据按 EL2N 排序选 1000 条,「只点击」这种高频行为会占据大部分名额,「点击+点赞+关注」这种稀有但高价值的行为几乎选不到。DIET 的做法是:对 8 种 label 组合分别独立排序,各自选 Top-K,保证每种行为模式都有代表——这些稀有样本对模型拟合「边界附近」的行为至关重要。
2.1.2 持续合成记忆(历史对齐 + 融合)
只用当前数据块初始化合成数据,会遗漏历史积累的决策边界信息。DIET 将合成数据建模为「持续演化的边界记忆」:对历史合成样本 $(e_i, \tilde{y}_i) \in \mathcal{D}^{syn}_{1:t-1}$,估计其与当前新选边界样本集 $\mathcal{S}_t$ 的对齐程度:
$\alpha_i$ 大的历史样本表示它依然与当前决策边界对齐(未「过时」),保留为 $\tilde{\mathcal{D}}^{syn}_{1:t-1}$。最终本阶段合成记忆由「当前新边界样本」和「对齐的历史记忆」融合而成:
推荐系统的决策边界是随时间演化的——新阶段的用户行为是在历史基础上叠加的。若每次丢弃历史合成记忆,合成数据只能反映当前阶段的局部边界快照,而对「历史中的关键长尾用户」失去记忆。对齐估计函数 $\mathcal{A}$ 保证了只保留仍然有效的历史信息,不引入过时的噪声。
2.2 Phase 2:基于影响力寻址的合成数据优化
光有好的初始化还不够,合成数据本身作为可微参数需要通过 bi-level 优化来精细调整。但如果对所有合成样本都做更新,会引入冗余甚至冲突的信号。DIET 引入影响力引导的双向记忆寻址,选择最值得更新的「合成-真实」配对。
2.2.1 影响力评分
对优化步骤 $t$(模型参数 $w_t$),用训练样本子集 $S$ 更新一步后在验证样本 $z$ 上的 loss 变化作为效用:
经一阶 Taylor 展开化简后,每个样本的影响力(Shapley 值近似)正比于梯度内积,定义影响力评分为:
- $\mathcal{A}_t(x, z)$ 越大:用样本 $x$ 训练,在验证样本 $z$ 上 loss 下降越多,说明 $x$ 对 $z$ 有正向影响
- 这是对 Influence Function 的高效近似,避免了计算 Hessian 逆矩阵的高昂代价
2.2.2 双向记忆寻址
选难目标(Hard Real Targets $\mathcal{B}^{hard}_t$):对每个真实训练样本 $z \in \mathcal{D}_t$,计算当前合成记忆对它的总对齐程度(Deficiency):
Deficiency 最小的样本是合成记忆「照顾最少」的难样本,选这些作为外层优化的锚点 $\mathcal{B}^{hard}_t$。
选活跃合成单元(Active Memory $\mathcal{M}^{active}_t$):对每个合成样本 $x \in \mathcal{D}^{syn}_t$,计算它对难样本集 $\mathcal{B}^{hard}_t$ 的更新责任(Responsibility):
Responsibility 最大的合成样本是最有潜力减小难样本 loss 的「活跃单元」,只对这些单元做更新,其他合成样本保持不动。
假设合成记忆有 100 条样本,当前数据块有 1 万条真实样本。模型刚训练完第 2 阶段,对「三件套商品用户」的行为预测还很差(loss 高)。
Deficiency 选出:「三件套商品用户的交互」这批真实样本的梯度与现有合成记忆的梯度对齐最差——它们是 $\mathcal{B}^{hard}_t$。
Responsibility 选出:合成记忆中,「与三件套用户梯度方向最接近的那 20 条合成样本」——它们是 $\mathcal{M}^{active}_t$,只更新这 20 条,其他 80 条不动。这既保证了更新有针对性,也避免了修改已经学好的历史边界信息。
2.2.3 Bi-level 优化框架(内层 + 外层 + RaT-BPTT)
代理模型 $\theta$ 初始化自参考模型 $\phi_{t-1}$($t=1$ 时随机初始化),在活跃合成记忆 $\mathcal{M}^{active}_t$ 上跑 $M$ 步梯度更新,得到模拟轨迹 $\{\theta^0, \ldots, \theta^M\}$:
用难样本集 $\mathcal{B}^{hard}_t$ 上的 loss 作为 meta 目标:$\mathcal{L}_{meta} = \mathcal{L}_{real}(\theta^M;\, \mathcal{B}^{hard}_t)$,通过内层轨迹反向传播来更新合成数据 $\mathcal{D}^{syn}_t$。
完整 $M$ 步展开的反向传播显存代价是 $O(M \cdot |\theta|)$,RaT-BPTT(Random Truncated BPTT)只保留 $K_{rat} \ll M$ 步的随机窗口做反向传播,大幅降低显存和计算量,同时实验表明对收敛质量影响有限。
2.3 下游模型训练(Warmup-style)
完成历史阶段蒸馏后,得到紧凑合成数据集 $\mathcal{D}^{syn}_{1:T}$。候选模型的训练分两步:
Step 1:Dense 参数从合成数据热身
用蒸馏数据训练候选模型的 dense 网络参数(MLP 层等),让其从合成数据中继承历史训练动态。
Step 2:Embedding 直接继承参考模型
候选模型的 sparse embedding 参数直接从参考模型 $\phi_T$ 的 checkpoint 中拷贝,保留参考模型对用户和物品的表示知识,无需重训 embedding。
随后在后续数据块上做正常的持续训练(fine-tune)即可。
场景:快手电商推荐,每天新增 1 亿条交互,已积累 3 个月历史($T=3$ 个阶段),现在要验证一个新架构 WuKong。
Phase 1:用现有 DCN 参考模型,对每天数据块用 EL2N 选出约 150 万条边界样本(约 1.5%),转为 embedding+soft label,与前两阶段保留的对齐历史记忆融合,得到 $\mathcal{D}^{syn}_t$(总共约 3×150万=450万,再经压缩约 50-100万条)。
Phase 2:用影响力寻址找出最难真实样本 $\mathcal{B}^{hard}$ 和最活跃合成单元 $\mathcal{M}^{active}$,做 bi-level 优化 + RaT-BPTT,精细调整合成数据使其更好还原全量训练动态。
验证新架构:WuKong 用上述 ~100 万条合成数据做 warmup(dense 参数),embedding 从参考模型 $\phi_3$ 继承,再在当天新数据上 fine-tune——全程不需要访问 3 个月全量历史!实验显示 WuKong 在 Tmall 上的 AUC 达 0.7617,接近全量训练的 0.7629。
3.1 实验设置
数据集
- KuaiRand:快手真实推荐数据,多任务(CTR + Like + Follow),约 1M 交互
- Tmall:阿里巴巴电商,点击 + 购买两任务
- Taobao:阿里大规模电商数据
评估协议
- 按时间切分 $T$ 个连续阶段(模拟真实流式场景)
- 指标:AUC(越高越好),与全量数据训练结果的 Pearson 相关系数 $r$
- Reference Model(教师):DCN-v2
- Candidate Models(下游验证):DCN-v2、WuKong、RankMixer
| 方法 | 类型 | KuaiRand 相关性 $r$ | Tmall 相关性 $r$ | 说明 |
|---|---|---|---|---|
| Random Selection | 子集选择 | 0.319 | 0.287 | 与全量训练几乎不相关 |
| Window Sampling | 时间窗口 | 0.681 | 0.598 | 有改善但偏差仍大 |
| K-Means Selection | 聚类子集 | 0.721 | 0.644 | 代表性更好但静态 |
| EL2N Selection | 重要性子集 | 0.743 | 0.672 | 边界感知但无蒸馏 |
| DIET(本文) | 流式数据集蒸馏 | 0.912 | 0.887 | 大幅领先所有基线 |
| Full Data(上界) | 全量训练 | 1.000 | 1.000 | 理论上界 |
3.2 跨架构泛化
用 DCN-v2 作为 Reference Model 蒸馏出的合成数据,是否能有效训练不同架构的 Candidate Model?论文分别在 WuKong 和 RankMixer 上验证:
| Candidate Model | 压缩比 | KuaiRand AUC(DIET) | KuaiRand AUC(Full Data) | 差距 |
|---|---|---|---|---|
| WuKong | 0.3% | 0.7601 | 0.7614 | -0.0013 |
| WuKong | 1.0% | 0.7612 | 0.7614 | -0.0002 |
| WuKong | 2.0% | 0.7615 | 0.7614 | +0.0001 |
| RankMixer | 1.0% | 0.7589 | 0.7607 | -0.0018 |
关键发现:1% 数据即可将 WuKong 的 AUC 差距缩小到 0.0002 以内,跨架构效果同样成立——DCN 蒸馏出的合成数据对 WuKong 和 RankMixer 同样有效,说明合成数据捕捉到了架构无关的训练动态。
3.3 消融实验
| 消融变体 | 移除的组件 | 相关性 $r$(KuaiRand) |
|---|---|---|
| DIET-Full | —(完整方法) | 0.912 |
| w/o 持续记忆融合 | 每阶段从头初始化合成数据 | 0.841 |
| w/o 影响力寻址 | 随机选样参与 bi-level 更新 | 0.873 |
| w/o RaT-BPTT | 完整 M 步反向传播(慢 ~3×) | 0.907 |
| Hard Label 替代 Soft Label | 用 one-hot 代替软化 logit | 0.869 |
两个最关键的组件是持续记忆融合(r 下降 0.071)和影响力寻址(r 下降 0.039)。RaT-BPTT 是效率近似,对质量影响有限,但省去 ~3× 的计算开销。
3.4 效率分析
蒸馏开销是一次性的(参考模型训练一次即可),后续所有 Candidate Model 的验证都复用同一批合成数据。设蒸馏开销为 $C_{distill}$,验证 $N$ 个模型所节省的总成本:
当 $N \geq 5$ 时节省率即超过 85%;论文引用的实际案例中,WuKong 架构验证成本降低 60×。
4.1 亮点
现有数据集蒸馏方法几乎全都是静态场景(ImageNet 子集蒸馏、NLP benchmark 压缩),DIET 是第一个把这个范式带进「流式、时序演化、大规模工业推荐」场景的工作,问题定义本身就是贡献。
DCN 蒸馏出的合成数据,拿来训 WuKong 和 RankMixer 依然有效,说明这批数据捕捉到了「哪些样本是高影响力的训练信号」这个本质,而非 DCN 特有的归纳偏置。这对工业实践意义很大:参考模型可以用成熟的轻量架构,候选模型可以是任意新架构。
传统影响力函数需要计算 Hessian 逆矩阵(参数量大时计算量不可承受),DIET 用一阶 Taylor 近似(梯度内积)替代,计算量降到 $O(d)$,在工业推荐($d$ 通常数十万到百万)中可行。
4.2 局限与可能的改进方向
⚠️ 局限:参考模型依赖
合成数据的质量强依赖参考模型($\phi$)的能力。若参考模型本身学偏了(如过拟合某种数据偏置),合成数据也会继承这种偏置。论文指出「参考模型容量比架构同质性更关键」,但如何系统处理参考模型的偏置还是开放问题。
⚠️ 局限:embedding 参数继承的局限
下游模型直接继承参考模型的 embedding,适用于架构相同(embedding 维度一致)的场景,但跨维度迁移(如参考模型 256d,候选模型 512d)需要额外适配,论文未充分讨论。
💡 可能的改进:在线蒸馏
当前 Phase 2 的 bi-level 优化仍然需要离线跑,能否做成「每来一批数据就增量更新合成记忆」的在线蒸馏,使延迟更低,是一个有趣的方向。
💡 可能的改进:结合 diffusion/生成模型
当前合成数据是以 embedding 向量形式存储的,能否用条件生成模型(如 diffusion)来参数化「合成数据生成过程」,使合成数据可以随时根据当前数据分布重新采样,是更长期的研究方向。