Transformer 训练里,GEMM 往往是主战场。但一个反常点是:GEMM 越快,GEMM 周围那些“小算子”的显存搬运越刺眼。
2026 年 5 月发布的 arXiv 论文《CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programs》,编号 arXiv:2605.19269,盯上的就是这件事。CODA 的核心抽象不是重做训练框架,而是把大量非注意力计算写成 GEMM-plus-epilogue 程序。
我更在意的是它提出的问题:能不能少写一堆一次性的融合 kernel,同时保留接近专家 GEMM 的性能结构?如果答案成立,受影响最大的不是普通模型用户,而是 GPU kernel 开发者和训练平台团队。
瓶颈不只在 GEMM,也在它旁边的显存读写
标准 Transformer block 里,线性层当然重要。cuBLAS、CUTLASS、Triton 和手写 kernel,也已经把 GEMM 打磨得很深。
但训练图里不只有矩阵乘法。normalization、activation、residual update、reduction 这类操作,FLOPs 不一定大,却会反复读写中间张量。计算看起来轻,带宽账单不轻。
这也是 CODA 的切入口。它处理的不是注意力矩阵那块问题,也不是通信、优化器状态、并行调度这些系统问题。它主要面向标准 Transformer block 中非注意力部分的前向和反向计算。
这个边界很关键。论文里提到覆盖 nearly all non-attention computation,不能读成覆盖整个 Transformer 训练。FlashAttention 解决的是注意力里的显存落地问题;CODA 处理的是注意力之外,GEMM 周边的零碎搬运。
对训练平台团队来说,这个判断会影响投入方式。它不该成为“马上迁移训练栈”的理由,更适合作为 kernel 层面的专项验证:选几个高频 block,看搬运减少后,端到端训练 step 是否真的受益。
CODA 的方法:固定 GEMM mainloop,开放可组合 epilogue
CODA 的做法很克制:固定 GEMM mainloop,把周边计算放进可组合的 epilogue primitives。输出 tile 还在片上时,尽量完成 scaling、reduction、pairwise transformation、accumulation 等操作,减少中间张量写回 HBM 再读出的次数。
这条路线的重点不是“融合”两个字。手写融合 kernel 也能做融合,甚至性能上限很高。问题是每遇到一个 shape、一个 block 变体、一个反向路径,就可能多一份维护成本。
| 路线 | 做法 | 对工程团队的好处 | 现实约束 |
|---|---|---|---|
| 普通框架算子 | 每个算子相对独立执行 | 通用,接入成本低 | 中间张量读写多 |
| 手写融合 kernel | 针对固定模式深度定制 | 单点性能上限高 | 复用难,调试和维护贵 |
| CODA | GEMM mainloop + 可组合 epilogue | 借用 GEMM 性能结构,提升代码复用 | 主要面向标准 Transformer block 的非注意力前后向 |
论文声称,在代表性 Transformer workload 上,人类和 LLM 编写的 CODA kernels 都能达到较高性能。这里不要读成“LLM 已经能随手写高性能 CUDA”。更稳妥的理解是:受约束的抽象,可能让高性能 kernel 生成更容易落在可控范围内。
这对 GPU kernel 开发者有直接影响。与其一开始就为每个模型结构写专用融合 kernel,不如先把 normalization、activation、residual update、reduction 这些高频模式映射到 CODA 风格的 epilogue 原语上,观察表达力够不够。
如果表达力够,开发者少维护几套近似重复的 kernel。如果表达力不够,CODA 就会退回论文 workload 里的漂亮抽象,难以覆盖真实训练代码里的枝杈。
值不值得试,先看三件事
目前只能确认两点。CODA 给出了清晰的 GEMM-plus-epilogue 抽象;论文称人类和 LLM 编写的 CODA kernels 在代表性 Transformer workload 上达到较高性能。
看不清的地方也不少。原始线索里没有具体加速倍数、显存节省比例、硬件型号和完整 workload 细节支撑,不能替它补数字。也不能默认它已经被主流框架或厂商采用。
工程团队如果要评估,动作可以很具体。
| 要验证的问题 | 建议动作 | 判断标准 |
|---|---|---|
| 是否真能减少搬运 | 选 1-2 个标准 Transformer block 做 kernel spike | 看训练 step 中相关段落是否变短,而不是只看单 kernel |
| 是否比手写融合省维护 | 对比同一模式下的手写 kernel 和 CODA 写法 | 看 shape 变化、前后向变化时要改多少代码 |
| 是否能进现有栈 | 放到 PyTorch、Triton、CUTLASS 等既有流程旁边评估 | 看调试、自动微分、调度衔接是否拖慢团队 |
这三件事比单个跑分更接近真实决策。训练成本不是论文表格里的一个 kernel 数字,而是开发时间、调试时间、回归风险和硬件利用率加起来的账。
我不太买账的是把 CODA 讲成“下一代训练系统”。证据还不够。它更像一把窄刀,专门削 Transformer block 里非注意力计算的显存搬运。
窄不等于小。大模型训练里,很多成本就藏在这些不起眼的搬运里。小处见功,前提是它能在更多 shape、更多模型变体和真实训练栈里继续守住性能与可维护性。
