共计 1134 个字符,预计需要花费 3 分钟才能阅读完成。
前言
Embedding table 优化关键的一点是内存空间占用的优化,比如一个id类特征几个亿,维度32维,你要生成几亿*32的矩阵,这个存储空间消耗可是很大,所以今天介绍其中一个方法,就是 muti hash 的方法,这个是在最近阿里开源的deeprec框架上面,不知道这个框架能不能持续发展下去,看看xdl就很心寒,其他家推出的开源框架基本是只是上传了出版代码后面的就没了。主要大模型这块,不光是训练框架,其他的涉及推理等等都是配套的,这块也算是技术壁垒,毕竟各家都是自研的框架。
今天介绍的 multi hash 的方法就是要解决embedding 维度过大,存储空间占用过多的问题,论文可以参考《Compositional Embeddings Using Complementary Partitions for Memory-Efficent Recommendation Systems》。
先来看下传统的 embedding table
如果id 特征有 N 种取值情况,那么你就要生成 N*Embedding_size大小的矩阵用于存储向量,这个消耗的内存过于庞大,毕竟id特征还不少。当然最主要的一般都是用户的id 和物料的id ,动辄上亿维度。
multi hash 的方法就是要减少 N 的维度。既然方法里面提到hash,这也是这个方法的重点。
算法逻辑
QUOTIENT-REMAINDER TRICK
算法的详细计算如下图所示:
- 构建两个矩阵,两个矩阵是互补的关系,注意两个的维度大小一个是m\times D 一个是\frac{\vert S\vert}{m}\times D
- 计算查询的索引i对应在两个embedding table中的索引
- 获取到对应的embedding数据
- 对从W_1和W_2查到的向量做elemen-wise计算
泛化的QR方法
这里需要强调的是构建多个 Partitions 也就是分成多少组,对应多少个 Embedding table。
举个例子:
S = {0, 1, 2, 3, 4}. 分成3个Partition如下所示
{{0}, {1, 3, 4}, {2}}, {{0, 1, 3}, {2, 4}}, {{0, 3}, {1, 2, 4}}.
到这一步之后下一步还是要结合多个 Partition下查询出来的embedding 结合问题,这里文章给出了三种方式:concat 、addition和element-wise 乘积
实验结论
Full table可以看作是ground truth标准,可以看到 qr trick 相比于 hash trick效果好,这个就是hash trick 冲突带来的问题,qr trick 一定程度上是做了缓解,就是hash trick 和full table之间的折中方案,算是效果和工程实现兼顾,对于工业界来说能落地才是最重要的环节。