Skip to content

Commit 8358f92

Browse files
jamesjwufacebook-github-bot
authored andcommitted
Revamp PT2 Compile/chromium event logging [1/?]
Summary: X-link: pytorch/pytorch#138093 This diff is the starting steps of https://docs.google.com/document/u/2/d/1kAEBt4AyW7HTAhXHbjoz8FBFHNyyEA2Qo2mPn7v3WUQ/edit?usp=drive_web&ouid=113555078003219714709 It implements the following changes: - Only log spans to scuba, so no start events are ever logged - Log events as the full event name, without "START" or "END" - Only log to scuba major phases from chromium events. These are: - entire_frame_compile (dynamo) - backend_compile (aotdispatch) - inductor_compile (inductor) - codegen (inductor codegen) Tlparse chromium events stay basically the same. But I implemented a few changes to clean that up as well: - When there's a phase name available, log the phase name instead of the function name as the event name. This simplifies the trace to not have two identical rows. The fn_name is avaliable as metadata on the chromium event, if interested - Log new events for pre and post grad passes. These do *not* log to scuba. By making the phases much simpler in Scuba, with only categories for major phases of PT2 Compilation, we pave the way to add **much** more metadata and information to each individual event type. Diffs for that will come later. **IMPLEMENTATION NOTES:** - The logic for `log_chromium_event_internal` (which is the function that logs to Scuba) lives in chromium_events for now, but in the future as we add more metadata, it may belong independently in dynamo_timed or even outside of dynamo_timed. I haven't explored in detail what the refactor will look like. Once we start logging metadata for dynamo, aotdispatch, inductor, I suspect we will call log_pt2_compile_event directly, instead of making chromium event logger handle the pt2_compile_event logic. But that refactor is left for another PR on top of this one. - There's an interesting space after pre grad passes within AOT autograd logic, that's between create_aot_dispatcher_function and pre grad passes. I'm not sure what we're spending time doing in that time, but I'll find out with a profile later. ghstack-source-id: 248790387 Reviewed By: oulgen Differential Revision: D64479033 fbshipit-source-id: 1f30e734160bfed2f664063b5b2f4df1b661dfa4
1 parent e89c1b3 commit 8358f92

File tree

1 file changed

+20
-37
lines changed
  • userbenchmark/dynamo/dynamobench/_dynamo

1 file changed

+20
-37
lines changed

userbenchmark/dynamo/dynamobench/_dynamo/utils.py

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -236,16 +236,6 @@ def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) -
236236
_add_time_spent(key, "remote_cache_time_saved", time_saved)
237237

238238

239-
def get_cache_stats() -> Dict[str, Any]:
240-
"""Get a bunch of metadata about cache hits and misses to use in chromium events"""
241-
cache_stats = {
242-
"fxgraph_cache_hit": counters["inductor"]["fxgraph_cache_hit"],
243-
"fxgraph_cache_miss": counters["inductor"]["fxgraph_cache_miss"],
244-
"fxgraph_cache_bypass": counters["inductor"]["fxgraph_cache_bypass"],
245-
}
246-
return cache_stats
247-
248-
249239
# dynamo_timed is a context manager
250240
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
251241
# where the key is the functions name.
@@ -290,9 +280,10 @@ def dynamo_timed(
290280
try:
291281
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
292282
t0 = time.time()
293-
chromium_log.log_event_start(key, start, None)
294283
if phase_name:
295-
chromium_log.log_event_start(phase_name, start)
284+
chromium_log.log_event_start(phase_name, start, {"fn_name": key})
285+
else:
286+
chromium_log.log_event_start(key, start, {})
296287
yield
297288
time_spent = time.time() - t0
298289
compilation_time_metrics[key].append(time_spent)
@@ -306,16 +297,15 @@ def dynamo_timed(
306297
chromium_log.log_event_end(
307298
phase_name,
308299
time.time_ns(),
309-
{"cache_stats": get_cache_stats()},
300+
{},
310301
start,
311302
)
312-
chromium_log.log_event_end(
313-
key, time.time_ns(), {"cache_stats": get_cache_stats()}, start
314-
)
303+
else:
304+
chromium_log.log_event_end(key, time.time_ns(), {}, start)
315305
# Only record backward compilation metrics if phase_name is not None!
316306
if phase_name:
317307
frame_key = str(curr_frame)
318-
# fwd only compilation stages: entire_frame_compile, backend_compile.
308+
# fwd only compilation stages: entire_frame_compile, backend_compile, aotdispatch.
319309
# use frame_key as time aggregation key.
320310
if fwd_only and fail_type is None:
321311
_add_time_spent(frame_key, phase_name, time_spent)
@@ -902,7 +892,7 @@ def log_event_start(
902892
self,
903893
event_name: str,
904894
time_ns: int,
905-
metadata: Optional[Dict[str, Any]] = None,
895+
metadata: Dict[str, Any],
906896
) -> None:
907897
"""
908898
Logs the start of a single event.
@@ -911,19 +901,14 @@ def log_event_start(
911901
:param metadata: Any extra metadata associated with this event
912902
"""
913903

914-
# Add compile id to metadata
915-
if metadata is None:
916-
metadata = {}
917904
compile_id = str(torch._guards.CompileContext.current_compile_id())
918905
metadata["compile_id"] = compile_id
919-
920-
event = self._log_timed_event(
906+
self._log_timed_event(
921907
event_name,
922908
time_ns,
923909
"B",
924910
metadata,
925911
)
926-
log_chromium_event_internal(event, self.get_stack(), compile_id, self.id_)
927912
self.get_stack().append(event_name)
928913

929914
def reset(self) -> None:
@@ -937,8 +922,8 @@ def log_event_end(
937922
self,
938923
event_name: str,
939924
time_ns: int,
940-
metadata: Optional[Dict[str, Any]] = None,
941-
start_time_ns: Optional[int] = None,
925+
metadata: Dict[str, Any],
926+
start_time_ns: int,
942927
) -> None:
943928
"""
944929
Logs the end of a single event. This function should only be
@@ -947,11 +932,14 @@ def log_event_end(
947932
:param time_ns: Timestamp in nanoseconds
948933
:param metadata: Any extra metadata associated with this event
949934
"""
950-
# Add compile id to metadata
951-
if metadata is None:
952-
metadata = {}
953935
compile_id = str(torch._guards.CompileContext.current_compile_id())
954936
metadata["compile_id"] = compile_id
937+
event = self._log_timed_event(
938+
event_name,
939+
time_ns,
940+
"E",
941+
metadata,
942+
)
955943

956944
# These stack health checks currently never happen,
957945
# but they're written this way to future proof any weird event
@@ -963,13 +951,6 @@ def log_event_end(
963951
log.warning("ChromiumEventLogger: Start event not in stack, ignoring")
964952
return
965953

966-
event = self._log_timed_event(
967-
event_name,
968-
time_ns,
969-
"E",
970-
metadata,
971-
)
972-
973954
while event_name != stack[-1]:
974955
# If the event isn't the most recent one to end, pop
975956
# off the stack until it is.
@@ -1046,7 +1027,9 @@ def log_instant_event(
10461027
expect_trace_id=True,
10471028
)
10481029
# Log an instant event with the same start and end time
1049-
log_chromium_event_internal(event, self.get_stack(), compile_id, self.id_)
1030+
log_chromium_event_internal(
1031+
event, self.get_stack(), compile_id, self.id_, time_ns
1032+
)
10501033

10511034

10521035
CHROMIUM_EVENT_LOG: Optional[ChromiumEventLogger] = None

0 commit comments

Comments
 (0)