HSTU Embedding 学习笔记

Torchrec

数据结构

JaggedTensor (jt)

“Jagged” 的意思是 “锯齿状的”,顾名思义,这种 Tensor 每一行的长度不一,可以将长度差异大的特征高效存储为一个 batch,避免 padding 造成资源浪费。如,可以将不同用户的交互历史作为一个 batch,放进 JaggedTensor 中存储。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# 存储用户交互历史:
# - User 1 交互了 2 个物品
# - User 2 交互了 3 个物品
# - User 3 交互了 1 个物品

# lengths 给出每个用户交互历史的长度:
# - User 1 交互历史长度为 2
# - User 2 交互历史长度为 3
# - User 3 交互历史长度为 1
lengths = torch.tensor([2, 3, 1], dtype=torch.int32)

# offsets 给出每个用户交互历史的区间: 
# - User 1 的交互历史在 values[0:2]
# - User 2 的交互历史在 values[2:5]
# - User 3 的交互历史在 values[5:6]
offsets = torch.tensor([0, 2, 5, 6], dtype=torch.int32)  

# 实际存储的是一个一维的 Tensor,通过划分来达到变长的效果
values = torch.Tensor([101, 102, 
                       201, 202, 203, 
                       301])  

jt = JaggedTensor(lengths=lengths, values=values)
print(jt)
jt = JaggedTensor(offsets=offsets, values=values)
print(jt)

values_list = jt.to_dense() # 转为 list of tensors
print(values_list)

outputs:

1
2
3
4
5
6
7
8
9
JaggedTensor({
    [[101.0, 102.0], [201.0, 202.0, 203.0], [301.0]]
})

JaggedTensor({
    [[101.0, 102.0], [201.0, 202.0, 203.0], [301.0]]
})

[tensor([101., 102.]), tensor([201., 202., 203.]), tensor([301.])]

KeyedJaggedTensor (kjt)

可以理解为多个 jt 的集合,同时给每个 jt 分配了一个标签,便于区分和访问。可以用来存储不同类别的的特征。

隐含假设:不同组特征之间是一一对应的,比如一组 user_features 对应一组 item_features,会自动平均划分,所以不需要指定 lengths 或者 offsets 中的元素对应哪组特征。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
keys = ["user_features", "item_features"] # 定义标签,特征包括用户特征和物品特征

# - 用户特征: 3 个用户,交互记录长度分别为 2, 0, 3
# - 物品特征: 3 个物品,特征长度分别为 1, 2, 1
lengths = torch.tensor([2, 0, 3, 1, 2, 1], dtype=torch.int32)
values = torch.Tensor([11, 12, 21, 22, 23, 101, 201, 202, 77])
# Create a KeyedJaggedTensor
kjt = KeyedJaggedTensor(keys=keys, 
                        lengths=lengths, 
                        values=values)

print(kjt)
# 通过 keys 访问不同特征
print(kjt["user_features"])
print(kjt["item_features"])

print('batch size:', kjt.stride())  # 获取批次大小

outputs:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
KeyedJaggedTensor({
    "user_features": [[11.0, 12.0], [], [21.0, 22.0, 23.0]],
    "item_features": [[101.0], [201.0, 202.0], [77.0]]
})

JaggedTensor({
    [[11.0, 12.0], [], [21.0, 22.0, 23.0]]
})

JaggedTensor({
    [[101.0], [201.0, 202.0], [77.0]]
})

batch size: 3

模型

torch.nn.Embedding

这是 torch 中的 Embedding 层,可以用 nn.Embedding(num_embeddings, embedding_dim) 定义一个词表大小为 num_embeddings ,嵌入维度为 embedding_dim 的 Embedding 层。

1
2
3
4
5
# an Embedding module containing 10 tensors of size 3
embedding = nn.Embedding(10, 3) # 初始权重是随机的
# a batch of 2 samples of 4 indices each
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
embedding(input)

outputs:

1
2
3
4
5
6
7
8
9
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])

EmbeddingCollection (EC)

这个可以看作是 torch.nn.Embedding 的升级版,给每组特征搭上了标签。具体地,传入 kjt ,得到 embedding 。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
e1_config = EmbeddingConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)

ec = EmbeddingCollection(tables=[e1_config, e2_config])

#     0       1        2  <-- batch
# 0   [0,1] None    [2]
# 1   [3]    [4]    [5,6,7]
# ^
# feature

features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1,                  2,    # feature 'f1'
                            3,      4,    5, 6, 7]),  # feature 'f2'
                    #    i = 1    i = 2    i = 3   <--- batch indices
    offsets=torch.tensor([
            0, 2, 2,       # 'f1' bags are values[0:2], values[2:2], and values[2:3]
            3, 4, 5, 8]),  # 'f2' bags are values[3:4], values[4:5], and values[5:8]
)

feature_embeddings = ec(features)
print(feature_embeddings)
print('------------------')
print('### f1 feature')
print(feature_embeddings["f1"].offsets())
print(feature_embeddings["f1"].values())
print('### f2 feature')
print(feature_embeddings["f2"].offsets())
print(feature_embeddings["f2"].values())

outputs:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
{'f1': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f7eb1ab2150>, 'f2': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f7eb1ab1d30>}
------------------
### f1 feature
tensor([0, 2, 2, 3])
tensor([[ 0.0548,  0.1086,  0.0119],
        [ 0.0635,  0.1062, -0.2560],
        [-0.2143, -0.0599, -0.2480]])
### f2 feature
tensor([0, 1, 2, 5])
tensor([[ 0.2697,  0.1827, -0.0657],
        [-0.1316, -0.2873,  0.3116],
        [-0.0632, -0.1552, -0.0938],
        [-0.3044,  0.0939,  0.0486],
        [ 0.1156, -0.0385,  0.0146]])

torch.nn.EmbeddingBag

这是 Embedding + pooling 层,pooling 模式有 sum, mean, max, 默认为 mean 。传入一个 batch 后,会查表计算出每个序号的 embedding,然后进行 pooling,不会保存中间状态。

1
2
3
4
5
6
7
8
9
# an EmbeddingBag module containing 10 tensors of size 3
embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
# a batch of 2 samples of 4 indices each
input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
offsets = torch.tensor([0, 4,], dtype=torch.long)
print(embedding_sum(input, offsets))

input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=torch.long)
print(embedding_sum(input))

outputs:

1
2
3
4
tensor([[-0.8861, -5.4350, -0.0523],
        [ 1.1306, -2.5798, -1.0044]])
tensor([[-0.8861, -5.4350, -0.0523],
        [ 1.1306, -2.5798, -1.0044]])

这里支持 offsets 参数,使用起始位置来划分每个 batch。所以可以传入 2D 的 inputs ,也可以传入 1D 的 inputs + offsets

EmbeddingBagCollection (EBC)

这个可以看作是 torch.nn.EmbeddingBag 的升级版。可以根据 featrue_names 对特征进行池化。具体地,传入 kjt ,前向传播得到池化后的特征。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)

ebc = EmbeddingBagCollection(tables=[table_0, table_1])

#        i = 0     i = 1    i = 2  <-- batch indices
# "f1"   [0,1]     None      [2]
# "f2"   [3]       [4]     [5,6,7]
#  ^
# features

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1,                  2,    # feature 'f1'
                            3,      4,    5, 6, 7]),  # feature 'f2'
                    #    i = 1    i = 2    i = 3   <--- batch indices
    offsets=torch.tensor([
            0, 2, 2,       # 'f1' bags are values[0:2], values[2:2], and values[2:3]
            3, 4, 5, 8]),  # 'f2' bags are values[3:4], values[4:5], and values[5:8]
)

pooled_embeddings = ebc(features)
print(pooled_embeddings)
print(pooled_embeddings.values())
print(pooled_embeddings.keys())
print(pooled_embeddings.offset_per_key())

outputs:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
KeyedTensor({
    "f1": [[-0.05368266999721527, -0.49714457988739014, 0.027471214532852173], [0.0, 0.0, 0.0], [-0.09605476260185242, -0.30402711033821106, -0.013938307762145996]],
    "f2": [[-0.20428268611431122, -0.25504255294799805, 0.23750828206539154, 0.3035320043563843], [0.057708702981472015, 0.14746926724910736, 0.05408177152276039, -0.2895013391971588], [-0.010172609239816666, -0.01704445481300354, -0.07303650677204132, 0.3184009790420532]]
})

tensor([[-0.0537, -0.4971,  0.0275, -0.2043, -0.2550,  0.2375,  0.3035],
        [ 0.0000,  0.0000,  0.0000,  0.0577,  0.1475,  0.0541, -0.2895],
        [-0.0961, -0.3040, -0.0139, -0.0102, -0.0170, -0.0730,  0.3184]],
       grad_fn=<CatBackward0>)
['f1', 'f2']
[0, 3, 7]

HSTU Embedding

embedding-preprocess