@@ -10,6 +10,7 @@ package testing
1010
1111import (
1212 "bytes"
13+ "context"
1314 "errors"
1415 "flag"
1516 "fmt"
@@ -78,6 +79,9 @@ type common struct {
7879 tempDir string
7980 tempDirErr error
8081 tempDirSeq int32
82+
83+ ctx context.Context
84+ cancelCtx context.CancelFunc
8185}
8286
8387type logger struct {
@@ -152,6 +156,7 @@ func fmtDuration(d time.Duration) string {
152156// TB is the interface common to T and B.
153157type TB interface {
154158 Cleanup (func ())
159+ Context () context.Context
155160 Error (args ... interface {})
156161 Errorf (format string , args ... interface {})
157162 Fail ()
@@ -307,6 +312,15 @@ func (c *common) Cleanup(f func()) {
307312 c .cleanups = append (c .cleanups , f )
308313}
309314
315+ // Context returns a context that is canceled just before
316+ // Cleanup-registered functions are called.
317+ //
318+ // Cleanup functions can wait for any resources
319+ // that shut down on [context.Context.Done] before the test or benchmark completes.
320+ func (c * common ) Context () context.Context {
321+ return c .ctx
322+ }
323+
310324// TempDir returns a temporary directory for the test to use.
311325// The directory is automatically removed by Cleanup when the test and
312326// all its subtests complete.
@@ -447,6 +461,9 @@ func (c *common) runCleanup() {
447461 if cleanup == nil {
448462 return
449463 }
464+ if c .cancelCtx != nil {
465+ c .cancelCtx ()
466+ }
450467 cleanup ()
451468 }
452469}
@@ -488,12 +505,15 @@ func (t *T) Run(name string, f func(t *T)) bool {
488505 }
489506
490507 // Create a subtest.
508+ ctx , cancelCtx := context .WithCancel (context .Background ())
491509 sub := T {
492510 common : common {
493- output : & logger {logToStdout : flagVerbose },
494- name : testName ,
495- parent : & t .common ,
496- level : t .level + 1 ,
511+ output : & logger {logToStdout : flagVerbose },
512+ name : testName ,
513+ parent : & t .common ,
514+ level : t .level + 1 ,
515+ ctx : ctx ,
516+ cancelCtx : cancelCtx ,
497517 },
498518 context : t .context ,
499519 }
@@ -606,9 +626,12 @@ func runTests(matchString func(pat, str string) (bool, error), tests []InternalT
606626 ok = true
607627
608628 ctx := newTestContext (newMatcher (matchString , flagRunRegexp , "-test.run" , flagSkipRegexp ))
629+ runCtx , cancelCtx := context .WithCancel (context .Background ())
609630 t := & T {
610631 common : common {
611- output : & logger {logToStdout : flagVerbose },
632+ output : & logger {logToStdout : flagVerbose },
633+ ctx : runCtx ,
634+ cancelCtx : cancelCtx ,
612635 },
613636 context : ctx ,
614637 }
0 commit comments