Skip to content

Refactoring and moving the EBC-only logic into benchmark_ebc.py #3251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/zch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import torch

from torchrec import EmbeddingConfig, KeyedJaggedTensor
from torchrec.distributed.benchmark.benchmark_utils import get_inputs
from tqdm import tqdm

from .sparse_arch import SparseArch
Expand Down
14 changes: 8 additions & 6 deletions torchrec/distributed/benchmark/benchmark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@
import torch

from torchrec.distributed.benchmark.benchmark_utils import (
benchmark_module,
BenchmarkResult,
CompileMode,
DLRM_NUM_EMBEDDINGS_PER_FEATURE,
EMBEDDING_DIM,
get_tables,
init_argparse_and_args,
write_report,
)
from torchrec.distributed.benchmark.embedding_collection_wrappers import (
benchmark_ebc_module,
get_tables,
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
from torchrec.distributed.test_utils.infer_utils import (
TestQuantEBCSharder,
Expand Down Expand Up @@ -84,7 +86,7 @@ def benchmark_qec(args: argparse.Namespace, output_dir: str) -> List[BenchmarkRe
if not argname.startswith("_") and argname not in IGNORE_ARGNAME
}

return benchmark_module(
return benchmark_ebc_module(
module=module,
sharder=sharder,
sharding_types=BENCH_SHARDING_TYPES,
Expand Down Expand Up @@ -118,7 +120,7 @@ def benchmark_qebc(args: argparse.Namespace, output_dir: str) -> List[BenchmarkR
if not argname.startswith("_") and argname not in IGNORE_ARGNAME
}

return benchmark_module(
return benchmark_ebc_module(
module=module,
sharder=sharder,
sharding_types=BENCH_SHARDING_TYPES,
Expand Down Expand Up @@ -153,7 +155,7 @@ def benchmark_qec_unsharded(
if not argname.startswith("_") and argname not in IGNORE_ARGNAME
}

return benchmark_module(
return benchmark_ebc_module(
module=module,
sharder=sharder,
sharding_types=[],
Expand Down Expand Up @@ -190,7 +192,7 @@ def benchmark_qebc_unsharded(
if not argname.startswith("_") and argname not in IGNORE_ARGNAME
}

return benchmark_module(
return benchmark_ebc_module(
module=module,
sharder=sharder,
sharding_types=[],
Expand Down
9 changes: 5 additions & 4 deletions torchrec/distributed/benchmark/benchmark_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@
from typing import List, Optional, Tuple

import torch

from torchrec.distributed.benchmark.benchmark_utils import (
benchmark_module,
BenchmarkResult,
CompileMode,
get_tables,
init_argparse_and_args,
set_embedding_config,
write_report,
)
from torchrec.distributed.benchmark.embedding_collection_wrappers import (
benchmark_ebc_module,
get_tables,
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
from torchrec.distributed.test_utils.test_model import TestEBCSharder
from torchrec.distributed.types import DataType
Expand Down Expand Up @@ -106,7 +107,7 @@ def benchmark_ebc(

args_kwargs["variable_batch_embeddings"] = variable_batch_embeddings

return benchmark_module(
return benchmark_ebc_module(
module=module,
sharder=sharder,
sharding_types=BENCH_SHARDING_TYPES,
Expand Down
Loading
Loading