Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 1 addition & 19 deletions pkg/controller/runtime/internal/controllerstate/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,24 +223,6 @@ func (adapter *StateAdapter) modify(
emptyResource.Metadata().Namespace(), emptyResource.Metadata().Type(), adapter.Name, emptyResource.Metadata().ID())
}

_, err := adapter.State.Get(ctx, emptyResource.Metadata())
if err != nil {
if state.IsNotFoundError(err) {
err = updateFunc(emptyResource)
if err != nil {
return nil, err
}

if err = adapter.State.Create(ctx, emptyResource, state.WithCreateOwner(adapter.Name)); err != nil {
return nil, err
}

return emptyResource, nil
}

return nil, fmt.Errorf("error querying current object state: %w", err)
}

updateOptions := []state.UpdateOption{state.WithUpdateOwner(adapter.Name)}

modifyOptions := controller.ToModifyOptions(options...)
Expand All @@ -250,7 +232,7 @@ func (adapter *StateAdapter) modify(
updateOptions = append(updateOptions, state.WithExpectedPhaseAny())
}

return adapter.State.UpdateWithConflicts(ctx, emptyResource.Metadata(), updateFunc, updateOptions...)
return adapter.State.ModifyWithResult(ctx, emptyResource, updateFunc, updateOptions...)
}

// AddFinalizer implements controller.Runtime interface.
Expand Down
26 changes: 26 additions & 0 deletions pkg/safe/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,32 @@ func StateWatchKind[T resource.Resource](ctx context.Context, st state.CoreState
return nil
}

// StateModify is a type safe wrapper around state.Modify.
func StateModify[T resource.Resource](ctx context.Context, st state.State, r T, fn func(T) error, options ...state.UpdateOption) error {
return st.Modify(ctx, r, func(r resource.Resource) error {
arg, ok := r.(T)
if !ok {
return fmt.Errorf("type mismatch: expected %T, got %T", arg, r)
}

return fn(arg)
}, options...)
}

// StateModifyWithResult is a type safe wrapper around state.ModifyWithResult.
func StateModifyWithResult[T resource.Resource](ctx context.Context, st state.State, r T, fn func(T) error, options ...state.UpdateOption) (T, error) {
got, err := st.ModifyWithResult(ctx, r, func(r resource.Resource) error {
arg, ok := r.(T)
if !ok {
return fmt.Errorf("type mismatch: expected %T, got %T", arg, r)
}

return fn(arg)
}, options...)

return typeAssertOrZero[T](got, err)
}

// List is a type safe wrapper around resource.List.
type List[T resource.Resource] struct {
list resource.List
Expand Down
98 changes: 98 additions & 0 deletions pkg/state/conformance/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package conformance

import (
"context"
"errors"
"fmt"
"math/rand"
"regexp"
Expand Down Expand Up @@ -1459,6 +1460,103 @@ func (suite *StateSuite) TestTeardownAndDestroy() {
suite.Require().NoError(eg.Wait())
}

// TestModify verifies Modify.
func (suite *StateSuite) TestModify() {
path1 := NewPathResource(suite.getNamespace(), "var/run/modify")

ctx := context.Background()

p1, err := safe.StateModifyWithResult(ctx, suite.State, path1, func(r *PathResource) error {
r.Metadata().Labels().Set("foo", "bar")

return nil
})
suite.Require().NoError(err)

suite.Assert().Equal(resource.String(path1), resource.String(p1))
suite.Assert().Equal("bar", p1.Metadata().Labels().Raw()["foo"])
suite.Assert().Empty(p1.Metadata().Owner())

p2, err := safe.StateGet[*PathResource](ctx, suite.State, path1.Metadata())
suite.Require().NoError(err)

suite.Assert().Equal(resource.String(path1), resource.String(p2))
suite.Assert().Equal("bar", p2.Metadata().Labels().Raw()["foo"])

p1, err = safe.StateModifyWithResult(ctx, suite.State, path1, func(r *PathResource) error {
r.Metadata().Labels().Delete("foo")

return nil
})
suite.Require().NoError(err)

suite.Assert().True(p1.Metadata().Labels().Empty())

p2, err = safe.StateGet[*PathResource](ctx, suite.State, path1.Metadata())
suite.Require().NoError(err)

suite.Assert().True(p2.Metadata().Labels().Empty())

_, err = safe.StateModifyWithResult(ctx, suite.State, path1, func(*PathResource) error {
return errors.New("modify error")
})
suite.Require().EqualError(err, "modify error")

_, err = suite.State.Teardown(ctx, path1.Metadata())
suite.Require().NoError(err)

_, err = safe.StateModifyWithResult(ctx, suite.State, path1, func(r *PathResource) error {
r.Metadata().Labels().Set("foo", "bar")

return nil
})
suite.Require().Error(err)
suite.Assert().True(state.IsPhaseConflictError(err))

p1, err = safe.StateModifyWithResult(ctx, suite.State, path1, func(r *PathResource) error {
r.Metadata().Labels().Set("foo", "bar2")

return nil
}, state.WithExpectedPhaseAny())
suite.Require().NoError(err)
suite.Assert().Equal(resource.PhaseTearingDown, p1.Metadata().Phase())
suite.Assert().Equal("bar2", p1.Metadata().Labels().Raw()["foo"])
}

// TestModifyWithOwner verifies Modify with Owner.
func (suite *StateSuite) TestModifyWithOwner() {
path1 := NewPathResource(suite.getNamespace(), "var/run/modify/owned")

ctx := context.Background()

p1, err := safe.StateModifyWithResult(ctx, suite.State, path1, func(*PathResource) error {
return nil
}, state.WithUpdateOwner("owner"))
suite.Require().NoError(err)

suite.Assert().Equal(resource.String(path1), resource.String(p1))
suite.Assert().Equal("owner", p1.Metadata().Owner())

p1, err = safe.StateModifyWithResult(ctx, suite.State, path1, func(r *PathResource) error {
r.Metadata().Labels().Set("foo", "bar")

return nil
}, state.WithUpdateOwner("owner"))
suite.Require().NoError(err)

suite.Assert().Equal(resource.String(path1), resource.String(p1))
suite.Assert().Equal("owner", p1.Metadata().Owner())
suite.Assert().Equal("bar", p1.Metadata().Labels().Raw()["foo"])

_, err = safe.StateModifyWithResult(ctx, suite.State, path1, func(r *PathResource) error {
r.Metadata().Labels().Set("foo", "baz")

return nil
})
suite.Require().Error(err)
suite.Require().True(state.IsOwnerConflictError(err))
}

func assertContextIsCanceled(t *testing.T, ctx context.Context) { //nolint:revive
t.Helper()

Expand Down
7 changes: 7 additions & 0 deletions pkg/state/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ func WithExpectedPhaseAny() UpdateOption {
}
}

// WithUpdateOptions sets update options for the update request.
func WithUpdateOptions(opts UpdateOptions) UpdateOption {
return func(options *UpdateOptions) {
*options = opts
}
}

// TeardownOptions for the CoreState.Teardown function.
type TeardownOptions struct {
Owner string
Expand Down
10 changes: 10 additions & 0 deletions pkg/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,14 @@ type State interface {
// It's not an error to tear down a resource which is already being torn down.
// The call blocks until all resource finalizers are empty.
TeardownAndDestroy(context.Context, resource.Pointer, ...TeardownAndDestroyOption) error

// Modify modifies an existing resource or creates a new one.
//
// It is a shorthand for Get+UpdateWithConflicts+Create.
Modify(ctx context.Context, emptyResource resource.Resource, updateFunc func(resource.Resource) error, options ...UpdateOption) error

// ModifyWithResult modifies an existing resource or creates a new one.
//
// It is a shorthand for Get+UpdateWithConflicts+Create.
ModifyWithResult(ctx context.Context, emptyResource resource.Resource, updateFunc func(resource.Resource) error, options ...UpdateOption) (resource.Resource, error)
}
49 changes: 49 additions & 0 deletions pkg/state/wrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ package state

import (
"context"
"fmt"

"github.com/siderolabs/go-pointer"

"github.com/cosi-project/runtime/pkg/resource"
)
Expand Down Expand Up @@ -237,3 +240,49 @@ func (state coreWrapper) TeardownAndDestroy(ctx context.Context, resourcePointer

return state.Destroy(ctx, resourcePointer, WithDestroyOwner(options.Owner))
}

// Modify modifies an existing resource or creates a new one.
//
// It is a shorthand for Get+UpdateWithConflicts+Create.
func (state coreWrapper) Modify(
ctx context.Context, emptyResource resource.Resource, updateFunc func(resource.Resource) error, options ...UpdateOption,
) error {
_, err := state.ModifyWithResult(ctx, emptyResource, updateFunc, options...)

return err
}

// ModifyWithResult modifies an existing resource or creates a new one.
//
// It is a shorthand for Get+UpdateWithConflicts+Create.
func (state coreWrapper) ModifyWithResult(
ctx context.Context, emptyResource resource.Resource, updateFunc func(resource.Resource) error, options ...UpdateOption,
) (resource.Resource, error) {
opts := UpdateOptions{
ExpectedPhase: pointer.To(resource.PhaseRunning),
}

for _, opt := range options {
opt(&opts)
}

_, err := state.Get(ctx, emptyResource.Metadata())
if err != nil {
if IsNotFoundError(err) {
err = updateFunc(emptyResource)
if err != nil {
return nil, err
}

if err = state.Create(ctx, emptyResource, WithCreateOwner(opts.Owner)); err != nil {
return nil, err
}

return emptyResource, nil
}

return nil, fmt.Errorf("error querying current object state: %w", err)
}

return state.UpdateWithConflicts(ctx, emptyResource.Metadata(), updateFunc, WithUpdateOptions(opts))
}