Commit de99b4d

emlin authored and facebook-github-bot committed
add block_bucketize_2d_weights kernel (#4778)
Summary: X-link: facebookresearch/FBGEMM#1801

The new embedding cache feature needs to distribute IDs together with their weights to the sharded embedding, and then write the embedding values directly instead of going through the backward pass and optimizer step. This feature will make it possible to extend the kvzch TBE into an embedding cache.

More information about the requirement details: https://docs.google.com/document/d/1_wlH1qoNOkK7nQmphCqBATDTHNvFJMclURwO1SDDJsQ/edit?tab=t.0#heading=h.sb19p9vr4aha

How to support it from kvzch: https://docs.google.com/document/d/1TJHKvO1m3-5tYAKZGhacXnGk7iCNAzz7wQlrFbX_LDI/edit?tab=t.0#heading=h.70vdya87lyup

To meet this requirement, we need a block bucketization kernel that accepts a 2D weight tensor, so that the data is distributed the same way as in input dist. To avoid polluting the original input dist kernel, and because pooling is not needed here, I added a separate kernel to support 2D weights.

Differential Revision: D78928659
1 parent 1d46cdd commit de99b4d

7 files changed: +2487 -0 lines changed

fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py

Lines changed: 118 additions & 0 deletions
@@ -496,3 +496,121 @@
         None)
     """,
 )
+
+add_docs(
+    torch.ops.fbgemm.block_bucketize_sparse_features_2d_weights,
+    """
+block_bucketize_sparse_features_2d_weights(lengths, indices, bucketize_pos, sequence, block_sizes, my_size, weights, weights_dim=1, batch_size_per_feature=None, max_B=-1, block_bucketize_pos=None, keep_orig_idx=False, total_num_blocks=None, keep_orig_idx_per_feature=None) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]
+
+Preprocess sparse features by partitioning sparse features into multiple
+buckets with support for 2D weights. Every feature is split into the same number of buckets, but the bucket
+sizes (widths) for the different features can be different. Moreover, the
+bucket sizes within each feature can be different.
+
+This function is similar to block_bucketize_sparse_features but supports 2D weights,
+where each index can have multiple weight values associated with it.
+
+Args:
+    lengths (Tensor): The lengths of the sparse features. The tensor contains
+        the lengths of each sample in a batch and each feature. Shape is `B *
+        T` where `B` is the batch size and `T` is the number of features
+
+    indices (Tensor): The sparse data. Only support integer types. Shape is the
+        sum of `lengths`
+
+    bucketize_pos (bool): If True, return the original relative indices within
+        a sample. For example, `indices = [9, 8, 2, 1, 0, 8, 9]` and `lengths =
+        [3, 4]`. The original relative indices within a sample for the indices
+        are `[0, 1, 2, 0, 1, 2, 3]`
+
+    sequence (bool): If True, return the new indices positions in the original
+        indices positions (the tensor is called `unbucketize_permute_data`).
+
+    block_sizes (Tensor): This tensor is used for the case where the bucket
+        size within a feature is uniform (i.e., when
+        `block_bucketize_pos=None`). The tensor contains bucket sizes (i.e.,
+        bucket widths) for each feature. `block_sizes[t]` represents the
+        bucket size of feature `t`. Shape is the number of features.
+
+    my_size (int): The number of buckets for each feature. Note that every
+        feature has the same number of buckets.
+
+    weights (Tensor): A float tensor that will be bucketized the same way as
+        `indices`. This tensor must have shape `[indices.size(0), weights_dim]`
+        where `weights_dim` is the dimension of the weight values for each index.
+
+    weights_dim (int = 1): The dimension of the weight values for each index.
+        This parameter is only used when `weights` is not None.
+
+    batch_size_per_feature (Optional[Tensor] = None): An optional tensor that
+        contains batch sizes for different features. If not None, batch sizes
+        are not uniform among features. Otherwise, the operator will assume
+        that the batch size is uniform and infer it from the `lengths` and
+        `block_sizes` tensors
+
+    max_B (int = -1): The max batch size. Must be set if
+        `batch_size_per_feature` is not None
+
+    block_bucketize_pos (Optional[List[Tensor]] = None): The input is used for
+        non-uniform bucket sizes within a feature. `block_bucketize_pos` is a
+        list of tensors. Each tensor contains the range offsets of buckets for
+        each feature. These range offsets are equivalent to the complete
+        cumulative sum of the bucket sizes. For example, `[0, 4, 20]` represents
+        two buckets. The first bucket size is `(4 - 0) = 4`, and the second
+        bucket size is `(20 - 4) = 16`. The length of `block_bucketize_pos`
+        must be equal to the number of features.
+
+    keep_orig_idx (bool = False): If True, return original indices instead of
+        the relative indices within each bucket
+
+    total_num_blocks (Optional[torch.Tensor] = None): An optional tensor that
+        contains the number of logical buckets (aka blocks) within a given
+        feature. This is useful for applications where the number of buckets
+        is more than the number of physical GPUs, which is common in cases
+        where we scale up/down the number of GPUs but want to maintain
+        the same numerical behavior.
+
+    keep_orig_idx_per_feature (Optional[Tensor] = None): An optional tensor that
+        contains whether to keep original indices for each feature. If not None,
+        the operator will use this tensor to determine whether to keep original
+        indices for each feature. If None, falls back to `keep_orig_idx`
+
+Return:
+    A tuple of tensors containing
+
+    (1) Bucketized lengths. Shape is `lengths.numel() * my_size`.
+
+    (2) Bucketized indices. Same shape as `indices`.
+
+    (3) Bucketized weights or None if `weights` is None. Shape is
+        `[indices.size(0), weights_dim]`.
+
+    (4) Bucketized positions or None if `bucketize_pos=False`. Same shape as
+        `indices`.
+
+    (5) `unbucketize_permute` or None if `sequence=False`. Same shape as
+        `indices`
+
+**Example**:
+
+    >>> # Generate input example. Batch size = 2. Number of features = 4
+    >>> lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=torch.int, device="cuda")
+    >>> indices = torch.tensor([3, 4, 15, 11, 28, 29, 1, 10, 11, 12, 13, 11, 22, 20, 20], dtype=torch.int, device="cuda")
+    >>> block_sizes = torch.tensor([[5, 15, 10, 20]], dtype=torch.int, device="cuda")
+    >>> my_size = 2 # Number of buckets
+    >>> weights_dim = 3 # Dimension of weight values for each index
+    >>> weights = torch.randn(indices.size(0), weights_dim, dtype=torch.float, device="cuda")
+    >>> # Invoke with keep_orig_idx=False, bucketize_pos=False, and
+    >>> # sequence=False
+    >>> torch.ops.fbgemm.block_bucketize_sparse_features_2d_weights(
+    >>>     lengths,
+    >>>     indices,
+    >>>     bucketize_pos=False,
+    >>>     sequence=False,
+    >>>     block_sizes=block_sizes,
+    >>>     my_size=my_size,
+    >>>     weights=weights,
+    >>>     weights_dim=weights_dim,
+    >>>     keep_orig_idx=False)
+    """,
+)
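The docstring above does not show the expected output of its example, so the following is a minimal pure-Python sketch of the uniform-bucket case (`block_bucketize_pos=None`, `keep_orig_idx=False`). It is a reading of the docstring, not the kernel added in this commit. The function name `block_bucketize_2d_weights_reference` is hypothetical, and the sketch assumes that `lengths` is laid out feature by feature, that the output is ordered bucket by bucket, and that every index fits inside `block_sizes[t] * my_size` (anything else raises instead of guessing the operator's behavior).

```python
from typing import List, Sequence, Tuple


def block_bucketize_2d_weights_reference(
    lengths: Sequence[int],
    indices: Sequence[int],
    block_sizes: Sequence[int],
    my_size: int,
    weights: Sequence[Sequence[float]],
) -> Tuple[List[int], List[int], List[List[float]], List[int]]:
    """Pure-Python model of uniform block bucketization with 2D weights."""
    num_rows = len(lengths)  # B * T entries, assumed feature-major
    batch_size = num_rows // len(block_sizes)

    # staged[bucket][row] collects (original position, bucket-relative index).
    staged = [[[] for _ in range(num_rows)] for _ in range(my_size)]
    offset = 0
    for row, length in enumerate(lengths):
        block = block_sizes[row // batch_size]
        for pos in range(offset, offset + length):
            idx = indices[pos]
            if not 0 <= idx < block * my_size:
                # Out-of-range handling is deliberately left out of this sketch.
                raise ValueError(f"index {idx} does not fit in {my_size} buckets")
            staged[idx // block][row].append((pos, idx % block))
        offset += length

    new_lengths: List[int] = []
    new_indices: List[int] = []
    new_weights: List[List[float]] = []
    unbucketize_permute = [0] * len(indices)
    for bucket in staged:  # bucket-major output order
        for entries in bucket:
            new_lengths.append(len(entries))
            for pos, rel_idx in entries:
                # Record where the original element `pos` ended up.
                unbucketize_permute[pos] = len(new_indices)
                new_indices.append(rel_idx)
                # The whole weight row travels with its index.
                new_weights.append(list(weights[pos]))
    return new_lengths, new_indices, new_weights, unbucketize_permute


# Same lengths/indices/block sizes as the docstring example, with made-up weights.
lengths = [0, 2, 1, 3, 2, 3, 3, 1]
indices = [3, 4, 15, 11, 28, 29, 1, 10, 11, 12, 13, 11, 22, 20, 20]
weights = [[float(i), i + 0.5, i + 0.25] for i in range(len(indices))]
new_lengths, new_indices, new_weights, permute = block_bucketize_2d_weights_reference(
    lengths, indices, [5, 15, 10, 20], my_size=2, weights=weights
)
print(new_lengths)  # [0, 2, 0, 1, 1, 0, 1, 0, 0, 0, 1, 2, 1, 3, 2, 1]
print(new_indices)  # [3, 4, 11, 1, 11, 0, 13, 14, 0, 1, 2, 3, 2, 0, 0]
```

Under these assumptions, `new_weights[permute[i]]` is the weight row that originally sat at position `i`, which is the property the embedding-cache use case in the summary relies on: each ID arrives at its bucket together with its full weight row.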

fbgemm_gpu/fbgemm_gpu/sparse_ops.py

Lines changed: 39 additions & 0 deletions
@@ -485,6 +485,41 @@ def block_bucketize_sparse_features_meta(
     )
 
 
+def block_bucketize_sparse_features_2d_weights_meta(
+    lengths: torch.Tensor,
+    indices: torch.Tensor,
+    bucketize_pos: bool,
+    sequence: bool,
+    block_sizes: torch.Tensor,
+    my_size: int,
+    weights: torch.Tensor,
+    weights_dim: int = 1,
+    batch_size_per_feature: Optional[torch.Tensor] = None,
+    max_B: int = -1,
+    block_bucketize_pos: Optional[torch.Tensor] = None,
+    keep_orig_idx: bool = False,
+    total_num_blocks: Optional[torch.Tensor] = None,
+    keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+]:
+    # Output: lengths, indices, weights, pos?, unbucketize_permute?
+    num_buckets = my_size
+    num_features = lengths.size(0)
+    num_values = indices.size(0)
+    return (
+        lengths.new_empty([num_buckets * num_features]),
+        indices.new_empty([num_values]),
+        weights.new_empty([num_values, weights_dim]),
+        indices.new_empty([num_values]) if bucketize_pos else None,
+        indices.new_empty([num_values]),
+    )
+
+
 def merge_pooled_embeddings(
     pooled_embeddings: List[torch.Tensor],
     uncat_dim_size: int,
@@ -1234,6 +1269,10 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
         "fbgemm::block_bucketize_sparse_features",
         block_bucketize_sparse_features_meta,
     )
+    impl_abstract(
+        "fbgemm::block_bucketize_sparse_features_2d_weights",
+        block_bucketize_sparse_features_2d_weights_meta,
+    )
     impl_abstract("fbgemm::merge_pooled_embeddings", merge_pooled_embeddings)
     impl_abstract(
         "fbgemm::permute_sparse_features", permute_sparse_features_abstract
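The meta function in the hunk above only computes output shapes. As a quick way to reason about them without torch, here is a small sketch that mirrors the same arithmetic on plain integers. The helper name `bucketize_2d_weights_meta_shapes` and its dictionary keys are hypothetical; the shape rules are copied from the diff, including the detail that the fifth output is allocated regardless of `sequence`.

```python
from typing import Dict, Optional, Tuple

Shape = Optional[Tuple[int, ...]]


def bucketize_2d_weights_meta_shapes(
    lengths_numel: int,
    num_values: int,
    my_size: int,
    weights_dim: int = 1,
    bucketize_pos: bool = False,
) -> Dict[str, Shape]:
    """Shapes that the meta function in this diff allocates, without torch."""
    return {
        # One length per (bucket, original lengths entry) pair.
        "lengths": (my_size * lengths_numel,),
        "indices": (num_values,),
        # 2D weights keep one row of `weights_dim` values per index.
        "weights": (num_values, weights_dim),
        "pos": (num_values,) if bucketize_pos else None,
        # The meta function allocates this output unconditionally,
        # i.e. it does not check `sequence`.
        "unbucketize_permute": (num_values,),
    }


# Sizes taken from the docstring example: 8 lengths, 15 indices, 2 buckets.
shapes = bucketize_2d_weights_meta_shapes(8, 15, my_size=2, weights_dim=3)
print(shapes["lengths"], shapes["weights"])  # (16,) (15, 3)
```

Note that `num_features` in the diff is really `lengths.size(0)`, which is `B * T` rather than `T`; the sketch names the argument `lengths_numel` to make that explicit.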

fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

Lines changed: 48 additions & 0 deletions
@@ -273,6 +273,54 @@ block_bucketize_sparse_features_inference_cpu(
     const std::optional<at::Tensor>& total_num_blocks,
     const std::optional<at::Tensor>& keep_orig_idx_per_feature);
 
+std::tuple<
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    std::optional<at::Tensor>,
+    std::optional<at::Tensor>>
+
+///@ingroup sparse-data-cuda
+block_bucketize_sparse_features_2d_weights_cuda(
+    const at::Tensor& lengths,
+    const at::Tensor& indices,
+    const bool bucketize_pos,
+    const bool sequence,
+    const at::Tensor& block_sizes,
+    const int64_t my_size,
+    const at::Tensor& weights,
+    const int64_t weights_dim,
+    const std::optional<at::Tensor>& batch_size_per_feature,
+    const int64_t max_batch_size,
+    const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
+    const bool keep_orig_idx,
+    const std::optional<at::Tensor>& total_num_blocks,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature);
+
+std::tuple<
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    std::optional<at::Tensor>,
+    std::optional<at::Tensor>>
+
+///@ingroup sparse-data-cpu
+block_bucketize_sparse_features_2d_weights_cpu(
+    const at::Tensor& lengths,
+    const at::Tensor& indices,
+    const bool bucketize_pos,
+    const bool sequence,
+    const at::Tensor& block_sizes,
+    const int64_t my_size,
+    const at::Tensor& weights,
+    const int64_t weights_dim,
+    const std::optional<at::Tensor>& batch_size_per_feature,
+    const int64_t max_batch_size,
+    const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
+    const bool keep_orig_idx,
+    const std::optional<at::Tensor>& total_num_blocks,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature);
+
 ///@ingroup sparse-data-cpu
 at::Tensor populate_bucketized_permute_cpu(
     const at::Tensor& length,
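Both declarations above take `block_bucketize_pos` as an optional list of tensors. Per the docstring in this commit, each tensor holds the cumulative range offsets of one feature's buckets (for example `[0, 4, 20]` for bucket widths 4 and 16). The sketch below shows one way such offsets can be turned into a bucket id with a binary search. The helper `locate_in_variable_buckets` is hypothetical. It also assumes that, with `keep_orig_idx=False`, the relative index is the distance from the bucket's lower bound; indices outside the covered range raise rather than guess the operator's fallback.

```python
from bisect import bisect_right
from typing import Sequence, Tuple


def locate_in_variable_buckets(idx: int, offsets: Sequence[int]) -> Tuple[int, int]:
    """Return (bucket, index relative to the bucket start) for one index.

    `offsets` is one entry of `block_bucketize_pos`: a sorted list of
    len(num_buckets) + 1 cumulative boundaries for a single feature.
    """
    if not offsets[0] <= idx < offsets[-1]:
        # Behavior outside the covered range is not modeled in this sketch.
        raise ValueError(f"index {idx} is outside [{offsets[0]}, {offsets[-1]})")
    # The last boundary that is <= idx is the bucket's lower bound.
    bucket = bisect_right(offsets, idx) - 1
    return bucket, idx - offsets[bucket]


offsets = [0, 4, 20]  # two buckets of width 4 and 16, as in the docstring
for idx in (3, 4, 19):
    print(idx, locate_in_variable_buckets(idx, offsets))
```

With the docstring's offsets, index 3 falls in the first bucket and indices 4 through 19 fall in the second, matching the stated widths of 4 and 16.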
