Skip to content

Commit 907ec48

Browse files
kausvfacebook-github-bot
authored andcommitted
OSS Hash MC Modules (pytorch#2797)
Summary: Pull Request resolved: pytorch#2797 This diff moves the Hash MC module to OSS directory Reviewed By: zlzhao1104 Differential Revision: D70970107 fbshipit-source-id: bba24ab79654f2a1f2c6c27aef3d6159df314d4b
1 parent f6bc1b2 commit 907ec48

File tree

6 files changed

+2117
-0
lines changed

6 files changed

+2117
-0
lines changed

torchrec/modules/hash_mc_evictions.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
#!/usr/bin/env python3
11+
12+
import logging
13+
import time
14+
from dataclasses import dataclass
15+
from enum import Enum, unique
16+
from typing import List, Optional, Tuple
17+
18+
import torch
19+
from pyre_extensions import none_throws
20+
21+
from torchrec.sparse.jagged_tensor import JaggedTensor
22+
23+
logger: logging.Logger = logging.getLogger(__name__)
24+
25+
26+
@unique
27+
class HashZchEvictionPolicyName(Enum):
28+
# eviction based on the time the ID is last seen during training,
29+
# and a single TTL
30+
SINGLE_TTL_EVICTION = "SINGLE_TTL_EVICTION"
31+
# eviction based on the time the ID is last seen during training,
32+
# and per-feature TTLs
33+
PER_FEATURE_TTL_EVICTION = "PER_FEATURE_TTL_EVICTION"
34+
# eviction based on least recently seen ID within the probe range
35+
LRU_EVICTION = "LRU_EVICTION"
36+
37+
38+
@torch.jit.script
39+
@dataclass
40+
class HashZchEvictionConfig:
41+
features: List[str]
42+
single_ttl: Optional[int] = None
43+
per_feature_ttl: Optional[List[int]] = None
44+
45+
46+
@torch.fx.wrap
47+
def get_kernel_from_policy(
48+
policy_name: Optional[HashZchEvictionPolicyName],
49+
) -> int:
50+
return (
51+
1
52+
if policy_name is not None
53+
and policy_name == HashZchEvictionPolicyName.LRU_EVICTION
54+
else 0
55+
)
56+
57+
58+
class HashZchEvictionScorer:
59+
def __init__(self, config: HashZchEvictionConfig) -> None:
60+
self._config: HashZchEvictionConfig = config
61+
62+
def gen_score(self, feature: JaggedTensor) -> torch.Tensor:
63+
return torch.empty(0, device="cuda")
64+
65+
def gen_threshold(self) -> int:
66+
return -1
67+
68+
69+
class HashZchSingleTtlScorer(HashZchEvictionScorer):
70+
def gen_score(self, feature: JaggedTensor) -> torch.Tensor:
71+
assert (
72+
self._config.single_ttl is not None
73+
), "To use scorer HashZchSingleTtlScorer, single_ttl is required."
74+
75+
return torch.full_like(
76+
feature.values(),
77+
# pyre-ignore [58]
78+
self._config.single_ttl + int(time.time() / 3600),
79+
dtype=torch.int32,
80+
device="cuda",
81+
)
82+
83+
def gen_threshold(self) -> int:
84+
return int(time.time() / 3600)
85+
86+
87+
class HashZchPerFeatureTtlScorer(HashZchEvictionScorer):
88+
def __init__(self, config: HashZchEvictionConfig) -> None:
89+
super().__init__(config)
90+
91+
assert self._config.per_feature_ttl is not None and len(
92+
self._config.features
93+
) == len(
94+
# pyre-ignore [6]
95+
self._config.per_feature_ttl
96+
), "To use scorer HashZchPerFeatureTtlScorer, a 1:1 mapping between features and per_feature_ttl is required."
97+
98+
self._per_feature_ttl = torch.IntTensor(self._config.per_feature_ttl)
99+
100+
def gen_score(self, feature: JaggedTensor) -> torch.Tensor:
101+
feature_split = feature.weights()
102+
assert feature_split.size(0) == self._per_feature_ttl.size(0)
103+
104+
scores = self._per_feature_ttl.repeat_interleave(feature_split) + int(
105+
time.time() / 3600
106+
)
107+
108+
return scores.to(device="cuda")
109+
110+
def gen_threshold(self) -> int:
111+
return int(time.time() / 3600)
112+
113+
114+
@torch.fx.wrap
115+
def get_eviction_scorer(
116+
policy_name: str, config: HashZchEvictionConfig
117+
) -> HashZchEvictionScorer:
118+
if policy_name == HashZchEvictionPolicyName.SINGLE_TTL_EVICTION:
119+
return HashZchSingleTtlScorer(config)
120+
elif policy_name == HashZchEvictionPolicyName.PER_FEATURE_TTL_EVICTION:
121+
return HashZchPerFeatureTtlScorer(config)
122+
elif policy_name == HashZchEvictionPolicyName.LRU_EVICTION:
123+
return HashZchSingleTtlScorer(config)
124+
else:
125+
return HashZchEvictionScorer(config)
126+
127+
128+
class HashZchThresholdEvictionModule(torch.nn.Module):
129+
"""
130+
This module manages the computation of eviction score for input IDs. Based on the selected
131+
eviction policy, a scorer is initiated to generate a score for each ID. The kernel
132+
will use this score to make eviction decisions.
133+
134+
Args:
135+
policy_name: an enum value that indicates the eviction policy to use.
136+
config: a config that contains information needed to run the eviction policy.
137+
138+
Example::
139+
module = HashZchThresholdEvictionModule(...)
140+
score = module(feature)
141+
"""
142+
143+
_eviction_scorer: HashZchEvictionScorer
144+
145+
def __init__(
146+
self,
147+
policy_name: HashZchEvictionPolicyName,
148+
config: HashZchEvictionConfig,
149+
) -> None:
150+
super().__init__()
151+
152+
self._policy_name: HashZchEvictionPolicyName = policy_name
153+
self._config: HashZchEvictionConfig = config
154+
self._eviction_scorer = get_eviction_scorer(
155+
policy_name=self._policy_name,
156+
config=self._config,
157+
)
158+
159+
logger.info(
160+
f"HashZchThresholdEvictionModule: {self._policy_name=}, {self._config=}"
161+
)
162+
163+
def forward(self, feature: JaggedTensor) -> Tuple[torch.Tensor, int]:
164+
"""
165+
Args:
166+
feature: a jagged tensor that contains the input IDs, and their lengths and
167+
weights (feature split).
168+
169+
Returns:
170+
a tensor that contains the eviction score for each ID, plus an eviction threshold.
171+
"""
172+
return (
173+
self._eviction_scorer.gen_score(feature),
174+
self._eviction_scorer.gen_threshold(),
175+
)
176+
177+
178+
class HashZchOptEvictionModule(torch.nn.Module):
179+
"""
180+
This module manages the eviction of IDs from the ZCH table based on the selected eviction policy.
181+
Args:
182+
policy_name: an enum value that indicates the eviction policy to use.
183+
Example:
184+
module = HashZchOptEvictionModule(policy_name=HashZchEvictionPolicyName.LRU_EVICTION)
185+
"""
186+
187+
def __init__(
188+
self,
189+
policy_name: HashZchEvictionPolicyName,
190+
) -> None:
191+
super().__init__()
192+
193+
self._policy_name: HashZchEvictionPolicyName = policy_name
194+
195+
def forward(self, feature: JaggedTensor) -> Tuple[None, int]:
196+
"""
197+
Does not apply to this Eviction Policy. Returns None and -1.
198+
Args:
199+
feature: No op
200+
Returns:
201+
None, -1
202+
"""
203+
return None, -1
204+
205+
206+
@torch.fx.wrap
207+
def get_eviction_module(
208+
policy_name: HashZchEvictionPolicyName, config: Optional[HashZchEvictionConfig]
209+
) -> torch.nn.Module:
210+
if policy_name in (
211+
HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
212+
HashZchEvictionPolicyName.PER_FEATURE_TTL_EVICTION,
213+
HashZchEvictionPolicyName.LRU_EVICTION,
214+
):
215+
return HashZchThresholdEvictionModule(policy_name, none_throws(config))
216+
else:
217+
return HashZchOptEvictionModule(policy_name)
218+
219+
220+
class HashZchEvictionModule(torch.nn.Module):
221+
"""
222+
This module manages the eviction of IDs from the ZCH table based on the selected eviction policy.
223+
Args:
224+
policy_name: an enum value that indicates the eviction policy to use.
225+
config: an optional config required if threshold based eviction is selected.
226+
Example:
227+
module = HashZchEvictionModule(policy_name=HashZchEvictionPolicyName.LRU_EVICTION)
228+
"""
229+
230+
def __init__(
231+
self,
232+
policy_name: HashZchEvictionPolicyName,
233+
config: Optional[HashZchEvictionConfig],
234+
) -> None:
235+
super().__init__()
236+
237+
self._policy_name: HashZchEvictionPolicyName = policy_name
238+
self._eviction_module: torch.nn.Module = get_eviction_module(
239+
self._policy_name, config
240+
)
241+
242+
logger.info(f"HashZchEvictionModule: {self._policy_name=}")
243+
244+
def forward(self, feature: JaggedTensor) -> Tuple[Optional[torch.Tensor], int]:
245+
"""
246+
Args:
247+
feature: a jagged tensor that contains the input IDs, and their lengths and
248+
weights (feature split).
249+
250+
Returns:
251+
For threshold eviction, a tensor that contains the eviction score for each ID, plus an eviction threshold. Otherwise None and -1.
252+
"""
253+
return self._eviction_module(feature)

0 commit comments

Comments
 (0)