diff --git a/README.md b/README.md index e27d6a4..6df2394 100644 --- a/README.md +++ b/README.md @@ -139,3 +139,37 @@ export KUBECONFIG=$(pwd)/dev.kubeconfig kubectx proxy kubectl --insecure-skip-tls-verify get namespace ``` + +## Embedded Mode + +The proxy supports an embedded mode that allows direct in-process connections without network overhead. +This is useful for applications that want to embed the proxy functionality directly. + +In embedded mode: +- No TLS/network layer - requests go directly through handlers +- Authentication via configurable HTTP headers (programmatic configuration only) +- High performance with sub-microsecond latency +- Compatible with standard HTTP clients and kubernetes client-go + +Embedded mode is designed for programmatic use when embedding the proxy in Go applications: + +```go +// Basic embedded mode setup +opts := proxy.NewOptions(proxy.WithEmbeddedProxy, proxy.WithEmbeddedSpiceDBEndpoint) + +// Complete configuration +completedConfig, _ := opts.Complete(ctx) +proxySrv, _ := proxy.NewServer(ctx, completedConfig) + +// Get client with automatic authentication headers +client := proxySrv.GetEmbeddedClient( + proxy.WithUser("alice"), + proxy.WithGroups("developers", "admin"), + proxy.WithExtra("department", "engineering"), +) + +// Or get a basic client without authentication +basicClient := proxySrv.GetEmbeddedClient() +``` + +See [docs/embedding.md](docs/embedding.md) for detailed usage examples. diff --git a/docs/embedding.md b/docs/embedding.md new file mode 100644 index 0000000..fdf435f --- /dev/null +++ b/docs/embedding.md @@ -0,0 +1,231 @@ +# Embedding + +The SpiceDB KubeAPI Proxy supports an embedded mode that allows you to integrate the proxy directly into your application without requiring TLS certificates or binding to network ports. +Under the hood, it uses the [`pkg/inmemory`](./pkg/inmemory/) transport package for zero-overhead HTTP communication. + +## Usage + +### Basic Setup + +```go +package main + +import ( + "context" + "net/http" + + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + + "github.com/authzed/spicedb-kubeapi-proxy/pkg/proxy" +) + +func main() { + ctx := context.Background() + + // Create options with embedded mode enabled + opts := proxy.NewOptions(proxy.WithEmbeddedProxy, proxy.WithEmbeddedSpiceDBEndpoint) + + // Configure your backend Kubernetes cluster + opts.RestConfigFunc = func() (*rest.Config, http.RoundTripper, error) { + // Return your cluster's REST config and transport + return myClusterConfig, myTransport, nil + } + + // SpiceDB is already configured for embedded mode via WithEmbeddedSpiceDBEndpoint + + // Set up your authorization rules + opts.RuleConfigFile = "rules.yaml" + + // Complete configuration + if err := opts.Complete(ctx); err != nil { + panic(err) + } + + // Create the proxy server + proxySrv, err := proxy.NewServer(ctx, *opts) + if err != nil { + panic(err) + } + + // Get an HTTP client that connects directly to the embedded proxy + // Use functional options to automatically add authentication headers + embeddedClient := proxySrv.GetEmbeddedClient( + proxy.WithUser("my-user"), + proxy.WithGroups("my-group", "admin"), + proxy.WithExtra("department", "engineering"), + ) + + // Create a Kubernetes client that uses the embedded proxy + k8sClient := createKubernetesClient(embeddedClient) + + // Use the client normally - all requests go through SpiceDB authorization + pods, err := k8sClient.CoreV1().Pods("default").List(ctx, metav1.ListOptions{}) + if err != nil { + panic(err) + } + + fmt.Printf("Found %d pods\n", len(pods.Items)) +} + +func createKubernetesClient(embeddedClient *http.Client) *kubernetes.Clientset { + restConfig := rest.CopyConfig(proxy.EmbeddedRestConfig) + restConfig.Transport = embeddedClient.Transport + + k8sClient, err := kubernetes.NewForConfig(restConfig) + if err != nil { + panic(err) + } + + return k8sClient +} +``` + +### Configuration Options + +## Configuration Options + +You can configure the proxy with different combinations of embedded options: + +### Full Embedded Mode (Proxy + SpiceDB) +```go +// Both proxy and SpiceDB run embedded +opts := proxy.NewOptions(proxy.WithEmbeddedProxy, proxy.WithEmbeddedSpiceDBEndpoint) +``` + +### Embedded Proxy with Remote SpiceDB +```go +// Proxy runs embedded, but connects to remote SpiceDB +opts := proxy.NewOptions(proxy.WithEmbeddedProxy) +opts.SpiceDBOptions.SpiceDBEndpoint = "localhost:50051" +opts.SpiceDBOptions.SecureSpiceDBTokensBySpace = "your-token" +``` + +### Regular Proxy with Embedded SpiceDB +```go +// Proxy runs with TLS termination, but uses embedded SpiceDB +opts := proxy.NewOptions(proxy.WithEmbeddedSpiceDBEndpoint) +``` + +### Example Configuration +```go +opts := proxy.NewOptions(proxy.WithEmbeddedProxy, proxy.WithEmbeddedSpiceDBEndpoint) + +// Backend Kubernetes cluster configuration +opts.RestConfigFunc = func() (*rest.Config, http.RoundTripper, error) { + // Your cluster configuration +} + +// SpiceDB configuration is already set to embedded via WithEmbeddedSpiceDBEndpoint +// For remote SpiceDB, you would instead use: +// opts.SpiceDBOptions.SpiceDBEndpoint = "localhost:50051" +// opts.SpiceDBOptions.SecureSpiceDBTokensBySpace = "your-token" + +// Authorization rules +opts.RuleConfigFile = "path/to/rules.yaml" +``` + +### Authentication Headers + +In embedded mode, authentication is handled via HTTP headers: + +The embedded proxy has a dedicated `EmbeddedAuthentication` configuration that is designed for programmatic use only. When embedding the proxy in your Go application, you can configure the header names through the `opts.Authentication.Embedded` struct: + +- `opts.Authentication.Embedded.UsernameHeaders` +- `opts.Authentication.Embedded.GroupHeaders` +- `opts.Authentication.Embedded.ExtraHeaderPrefixes` + +**Default Headers:** +- `X-Remote-User`: The username (required) +- `X-Remote-Group`: Group membership (can be specified multiple times) +- `X-Remote-Extra-*`: Extra user attributes (e.g., `X-Remote-Extra-Department: engineering`) + +**Example with default headers:** + +``` +X-Remote-User: alice +X-Remote-Group: developers +X-Remote-Group: admin +X-Remote-Extra-Department: engineering +X-Remote-Extra-Team: platform +``` + +**Example with custom headers (programmatic configuration):** + +```go +opts := proxy.NewOptions(proxy.WithEmbeddedProxy, proxy.WithEmbeddedSpiceDBEndpoint) + +// Configure custom header names +opts.Authentication.Embedded.UsernameHeaders = []string{"Custom-User"} +opts.Authentication.Embedded.GroupHeaders = []string{"Custom-Groups"} +opts.Authentication.Embedded.ExtraHeaderPrefixes = []string{"Custom-Extra-"} + +// Complete and create the proxy server +completedConfig, _ := opts.Complete(ctx) +proxySrv, _ := proxy.NewServer(ctx, completedConfig) + +// The client will automatically use the custom header names +embeddedClient := proxySrv.GetEmbeddedClient( + proxy.WithUser("alice"), + proxy.WithGroups("developers", "admin"), + proxy.WithExtra("department", "engineering"), +) +// Headers will be: Custom-User: alice, Custom-Groups: developers, etc. +``` + +This is similar to Kubernetes' request header authentication, but uses a separate dedicated `EmbeddedAuthentication` type for embedded mode and doesn't require client certificate configuration (the requests are trusted because the server is embedded). + +### Functional Options for GetEmbeddedClient + +The `GetEmbeddedClient()` method supports functional options that automatically add authentication headers based on your configured header names. This eliminates the need to manually add headers to each request: + +```go +// Basic client without authentication +client := proxySrv.GetEmbeddedClient() + +// Client with user authentication +client := proxySrv.GetEmbeddedClient( + proxy.WithUser("alice"), +) + +// Client with user and groups +client := proxySrv.GetEmbeddedClient( + proxy.WithUser("alice"), + proxy.WithGroups("developers", "admin", "reviewers"), +) + +// Client with user, groups, and extra attributes +client := proxySrv.GetEmbeddedClient( + proxy.WithUser("alice"), + proxy.WithGroups("developers", "admin"), + proxy.WithExtra("department", "engineering"), + proxy.WithExtra("team", "platform"), + proxy.WithExtra("location", "remote"), +) +``` + +The functional options automatically use the header names you've configured in `opts.Authentication.Embedded`. For example, if you've configured custom header names: + +```go +opts.Authentication.Embedded.UsernameHeaders = []string{"My-User"} +opts.Authentication.Embedded.GroupHeaders = []string{"My-Groups"} +opts.Authentication.Embedded.ExtraHeaderPrefixes = []string{"My-Extra-"} + +// This client will automatically add: +// My-User: alice +// My-Groups: developers +// My-Groups: admin +// My-Extra-department: engineering +client := proxySrv.GetEmbeddedClient( + proxy.WithUser("alice"), + proxy.WithGroups("developers", "admin"), + proxy.WithExtra("department", "engineering"), +) +``` + +Available functional options: +- `WithUser(username string)`: Sets the username +- `WithGroups(groups ...string)`: Sets group memberships +- `WithExtra(key, value string)`: Sets extra user attributes (can be called multiple times) + +This approach provides a clean, type-safe way to configure authentication without manually managing headers. diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index fd3a75c..3b4b2c9 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -19,13 +19,12 @@ import ( "os" "path" "path/filepath" - goruntime "runtime" + "testing" "time" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/spf13/afero" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/discovery/cached/disk" "k8s.io/client-go/informers" @@ -40,11 +39,6 @@ import ( "k8s.io/kubernetes/pkg/controller/garbagecollector" "sigs.k8s.io/controller-runtime/pkg/envtest" "sigs.k8s.io/controller-runtime/pkg/log/zap" - "sigs.k8s.io/controller-runtime/tools/setup-envtest/env" - "sigs.k8s.io/controller-runtime/tools/setup-envtest/remote" - "sigs.k8s.io/controller-runtime/tools/setup-envtest/store" - "sigs.k8s.io/controller-runtime/tools/setup-envtest/versions" - "sigs.k8s.io/controller-runtime/tools/setup-envtest/workflows" "github.com/authzed/spicedb-kubeapi-proxy/pkg/authz/distributedtx" "github.com/authzed/spicedb-kubeapi-proxy/pkg/proxy" @@ -118,7 +112,7 @@ var _ = SynchronizedBeforeSuite(func() []byte { Expect(err).To(Succeed()) clientCA = GenerateClientCA(port) - opts := proxy.NewOptions() + opts := proxy.NewOptions(proxy.WithEmbeddedSpiceDBEndpoint) opts.RestConfigFunc = func() (*rest.Config, http.RoundTripper, error) { conf, err := clientcmd.NewDefaultClientConfig(*backendCfg, nil).ClientConfig() if err != nil { @@ -130,7 +124,6 @@ var _ = SynchronizedBeforeSuite(func() []byte { } opts.RuleConfigFile = "rules.yaml" opts.SecureServing.BindPort = port - opts.SpiceDBOptions.SpiceDBEndpoint = proxy.EmbeddedSpiceDBEndpoint opts.SecureServing.BindAddress = net.ParseIP("127.0.0.1") opts.Authentication.BuiltInOptions.ClientCert.ClientCA = clientCA.Path() @@ -154,39 +147,9 @@ var _ = SynchronizedBeforeSuite(func() []byte { func ConfigureApiserver() { log := zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true)) - e := &env.Env{ - Log: log, - Client: &remote.HTTPClient{ - Log: log, - IndexURL: remote.DefaultIndexURL, - }, - Version: versions.Spec{ - Selector: versions.TildeSelector{}, - CheckLatest: false, - }, - VerifySum: true, - ForceDownload: false, - Platform: versions.PlatformItem{ - Platform: versions.Platform{ - OS: goruntime.GOOS, - Arch: goruntime.GOARCH, - }, - }, - FS: afero.Afero{Fs: afero.NewOsFs()}, - Store: store.NewAt("../testbin"), - Out: os.Stdout, - } - var err error - e.Version, err = versions.FromExpr("~1.33.0") - Expect(err).To(Succeed()) - - workflows.Use{ - UseEnv: true, - PrintFormat: env.PrintOverview, - AssetsPath: "../testbin", - }.Do(e) + assetsPath := setupEnvtest(log) - Expect(os.Setenv("KUBEBUILDER_ASSETS", fmt.Sprintf("../testbin/k8s/%s-%s-%s", e.Version.AsConcrete(), e.Platform.OS, e.Platform.Arch))).To(Succeed()) + Expect(os.Setenv("KUBEBUILDER_ASSETS", assetsPath)).To(Succeed()) DeferCleanup(os.Unsetenv, "KUBEBUILDER_ASSETS") } diff --git a/e2e/embedded_integration_test.go b/e2e/embedded_integration_test.go new file mode 100644 index 0000000..4c1313e --- /dev/null +++ b/e2e/embedded_integration_test.go @@ -0,0 +1,259 @@ +//go:build e2e + +package e2e + +import ( + "context" + "errors" + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + utilfeature "k8s.io/apiserver/pkg/util/feature" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + logsv1 "k8s.io/component-base/logs/api/v1" + "sigs.k8s.io/controller-runtime/pkg/envtest" + "sigs.k8s.io/controller-runtime/pkg/log/zap" + + "github.com/authzed/spicedb-kubeapi-proxy/pkg/config/proxyrule" + "github.com/authzed/spicedb-kubeapi-proxy/pkg/proxy" + "github.com/authzed/spicedb-kubeapi-proxy/pkg/rules" +) + +// TestEmbeddedModeIntegration tests the full integration of the embedded proxy mode +// with a real Kubernetes API server. +// Note: this is separate from the e2e tests because the setup/teardown is much less +// involved in embedded mode. +func TestEmbeddedModeIntegration(t *testing.T) { + defer require.NoError(t, logsv1.ResetForTest(utilfeature.DefaultFeatureGate)) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Configure the test environment to download binaries if needed + configureApiserver(t) + + // Start a real Kubernetes API server using envtest + testEnv := &envtest.Environment{ + ControlPlaneStopTimeout: 60 * time.Second, + } + + cfg, err := testEnv.Start() + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, testEnv.Stop()) + }) + + // Create custom bootstrap content for the test + bootstrapContent := map[string][]byte{ + "bootstrap.yaml": []byte(`schema: |- + definition cluster {} + definition user {} + definition namespace { + relation cluster: cluster + relation creator: user + relation viewer: user + + permission admin = creator + permission edit = creator + permission view = viewer + creator + permission no_one_at_all = nil + } + definition pod { + relation namespace: namespace + relation creator: user + relation viewer: user + permission edit = creator + permission view = viewer + creator + } + definition testresource { + relation namespace: namespace + relation creator: user + relation viewer: user + permission edit = creator + permission view = viewer + creator + } + definition lock { + relation workflow: workflow + } + definition workflow {} +relationships: | +`), + } + + // Create embedded proxy options with embedded mode enabled and custom bootstrap + opts := proxy.NewOptions(proxy.WithEmbeddedProxy, proxy.WithEmbeddedSpiceDBBootstrap(bootstrapContent)) + + // Configure to use the real test API server + opts.RestConfigFunc = func() (*rest.Config, http.RoundTripper, error) { + transport, err := rest.TransportFor(cfg) + if err != nil { + return nil, nil, err + } + // Make a copy to avoid modifying the original + configCopy := rest.CopyConfig(cfg) + return configCopy, transport, nil + } + + // Create simple rules for namespace operations + createNamespaceRule := proxyrule.Config{ + Spec: proxyrule.Spec{ + Matches: []proxyrule.Match{{ + GroupVersion: "v1", + Resource: "namespaces", + Verbs: []string{"create"}, + }}, + Update: proxyrule.Update{ + CreateRelationships: []proxyrule.StringOrTemplate{{ + Template: "namespace:{{name}}#creator@user:{{user.name}}", + }}, + }, + }, + } + + getNamespaceRule := proxyrule.Config{ + Spec: proxyrule.Spec{ + Matches: []proxyrule.Match{{ + GroupVersion: "v1", + Resource: "namespaces", + Verbs: []string{"get"}, + }}, + Checks: []proxyrule.StringOrTemplate{{ + Template: "namespace:{{name}}#creator@user:{{user.name}}", + }}, + }, + } + + matcher, err := rules.NewMapMatcher([]proxyrule.Config{ + createNamespaceRule, + getNamespaceRule, + }) + require.NoError(t, err) + opts.Matcher = matcher + + // Complete the configuration + completedConfig, err := opts.Complete(ctx) + require.NoError(t, err) + + // Create the embedded proxy server + proxySrv, err := proxy.NewServer(ctx, completedConfig) + require.NoError(t, err) + + // Start the proxy server + go func() { + err := proxySrv.Run(ctx) + if err != nil && !errors.Is(err, context.Canceled) { + t.Errorf("Proxy server failed: %v", err) + } + }() + + // Wait for the server to be ready + require.Eventually(t, func() bool { + httpClient := proxySrv.GetEmbeddedClient( + proxy.WithUser("testuser"), + proxy.WithGroups("users"), + ) + require.NotNil(t, httpClient) + + kubeClient, err := kubernetes.NewForConfigAndClient(proxy.EmbeddedRestConfig, httpClient) + require.NoError(t, err) + + // if ns create works, the server is ready + nsName := "test-namespace-" + fmt.Sprint(time.Now().UnixNano()) + namespace := &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + Name: nsName, + }, + } + _, err = kubeClient.CoreV1().Namespaces().Create(ctx, namespace, metav1.CreateOptions{}) + return err == nil + }, 10*time.Second, 100*time.Millisecond) + + t.Run("namespace creation works with proper authorization", func(t *testing.T) { + // Get embedded client for testuser + httpClient := proxySrv.GetEmbeddedClient( + proxy.WithUser("testuser"), + proxy.WithGroups("users"), + ) + require.NotNil(t, httpClient) + + // Create a Kubernetes client using the embedded HTTP client + kubeClient, err := kubernetes.NewForConfigAndClient(proxy.EmbeddedRestConfig, httpClient) + require.NoError(t, err) + + // Create a test namespace + nsName := "test-namespace-" + fmt.Sprint(time.Now().UnixNano()) + namespace := &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + Name: nsName, + }, + } + + // This should work with proper authorization rules + createdNs, err := kubeClient.CoreV1().Namespaces().Create(ctx, namespace, metav1.CreateOptions{}) + require.NoError(t, err) + + // Verify namespace was created correctly + require.Equal(t, nsName, createdNs.Name) + require.NotEmpty(t, createdNs.ResourceVersion) + + // user can get the namespace they created + retrievedNs, err := kubeClient.CoreV1().Namespaces().Get(ctx, nsName, metav1.GetOptions{}) + require.NoError(t, err) + require.Equal(t, nsName, retrievedNs.Name) + }) + + t.Run("different users get different clients", func(t *testing.T) { + // Get embedded clients for different users + adminClient := proxySrv.GetEmbeddedClient( + proxy.WithUser("admin"), + proxy.WithGroups("system:masters"), + ) + userClient := proxySrv.GetEmbeddedClient( + proxy.WithUser("testuser"), + proxy.WithGroups("developers"), + ) + + require.NotNil(t, adminClient) + require.NotNil(t, userClient) + + // Clients should be different instances + assert.NotEqual(t, adminClient, userClient) + }) + + t.Run("unauthenticated requests are rejected", func(t *testing.T) { + // Get embedded client without authentication + unauthHTTPClient := proxySrv.GetEmbeddedClient() + require.NotNil(t, unauthHTTPClient) + + // Create a Kubernetes client using the unauthenticated HTTP client + kubeClient, err := kubernetes.NewForConfigAndClient(proxy.EmbeddedRestConfig, unauthHTTPClient) + require.NoError(t, err) + + // Try to list namespaces - should be unauthorized + _, err = kubeClient.CoreV1().Namespaces().List(ctx, metav1.ListOptions{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "Unauthorized") + }) +} + +// configureApiserver sets up the test environment binaries for envtest +func configureApiserver(t *testing.T) { + t.Helper() + + // Create a logger compatible with setup-envtest + log := zap.New(zap.UseDevMode(true)) + + // Use the shared setupEnvtest function + assetsPath := setupEnvtest(log) + + // Set the KUBEBUILDER_ASSETS environment variable + t.Setenv("KUBEBUILDER_ASSETS", assetsPath) +} diff --git a/e2e/go.mod b/e2e/go.mod index a808961..f6bd70d 100644 --- a/e2e/go.mod +++ b/e2e/go.mod @@ -12,15 +12,17 @@ require ( github.com/onsi/gomega v1.37.0 github.com/samber/lo v1.50.0 github.com/spf13/afero v1.14.0 + github.com/stretchr/testify v1.10.0 k8s.io/api v0.33.1 - k8s.io/apimachinery v0.34.0-alpha.1 + k8s.io/apimachinery v0.34.0-alpha.2 k8s.io/apiserver v0.33.1 k8s.io/client-go v0.33.1 + k8s.io/component-base v0.33.1 k8s.io/controller-manager v0.33.1 k8s.io/kubernetes v1.33.1 k8s.io/utils v0.0.0-20250604170112-4c0f3b243397 sigs.k8s.io/controller-runtime v0.21.0 - sigs.k8s.io/controller-runtime/tools/setup-envtest v0.0.0-20250617162058-15c5d6129278 + sigs.k8s.io/controller-runtime/tools/setup-envtest v0.0.0-20250708091927-252af6420feb ) require ( @@ -222,7 +224,6 @@ require ( github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect github.com/stretchr/objx v0.5.2 // indirect - github.com/stretchr/testify v1.10.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tilinna/z85 v1.0.0 // indirect github.com/warpstreamlabs/bento v1.8.2 // indirect @@ -283,14 +284,13 @@ require ( k8s.io/apiextensions-apiserver v0.33.0 // indirect k8s.io/cloud-provider v0.33.0 // indirect k8s.io/cluster-bootstrap v0.0.0 // indirect - k8s.io/component-base v0.33.1 // indirect k8s.io/component-helpers v0.33.1 // indirect k8s.io/csi-translation-lib v0.0.0 // indirect k8s.io/dynamic-resource-allocation v0.0.0 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kms v0.33.1 // indirect k8s.io/kube-controller-manager v0.0.0 // indirect - k8s.io/kube-openapi v0.0.0-20250318190949-c8a335a9a2ff // indirect + k8s.io/kube-openapi v0.0.0-20250610211856-8b98d1ed966a // indirect k8s.io/kubelet v0.33.1 // indirect k8s.io/mount-utils v0.0.0 // indirect k8s.io/pod-security-admission v0.0.0 // indirect diff --git a/e2e/go.sum b/e2e/go.sum index a582744..6e7fb5c 100644 --- a/e2e/go.sum +++ b/e2e/go.sum @@ -2302,8 +2302,8 @@ k8s.io/kms v0.33.1 h1:jJKrFhsbVofpyLF+G8k+drwOAF9CMQpxilHa5Uilb8Q= k8s.io/kms v0.33.1/go.mod h1:C1I8mjFFBNzfUZXYt9FZVJ8MJl7ynFbGgZFbBzkBJ3E= k8s.io/kube-controller-manager v0.33.0 h1:wzd/I2N7X2UU2h3248pyKbyiu4YMGuzWgyRc72YnrWQ= k8s.io/kube-controller-manager v0.33.0/go.mod h1:NtTfXq9CQr1zMgAeIQJjF/Ct2p5K0eP0FhNoD6ePKUU= -k8s.io/kube-openapi v0.0.0-20250318190949-c8a335a9a2ff h1:/usPimJzUKKu+m+TE36gUyGcf03XZEP0ZIKgKj35LS4= -k8s.io/kube-openapi v0.0.0-20250318190949-c8a335a9a2ff/go.mod h1:5jIi+8yX4RIb8wk3XwBo5Pq2ccx4FP10ohkbSKCZoK8= +k8s.io/kube-openapi v0.0.0-20250610211856-8b98d1ed966a h1:ZV3Zr+/7s7aVbjNGICQt+ppKWsF1tehxggNfbM7XnG8= +k8s.io/kube-openapi v0.0.0-20250610211856-8b98d1ed966a/go.mod h1:5jIi+8yX4RIb8wk3XwBo5Pq2ccx4FP10ohkbSKCZoK8= k8s.io/kubelet v0.33.0 h1:4pJA2Ge6Rp0kDNV76KH7pTBiaV2T1a1874QHMcubuSU= k8s.io/kubelet v0.33.0/go.mod h1:iDnxbJQMy9DUNaML5L/WUlt3uJtNLWh7ZAe0JSp4Yi0= k8s.io/kubernetes v1.33.1 h1:86+VVY/f11taZdpEZrNciLw1MIQhu6BFXf/OMFn5EUg= @@ -2374,8 +2374,8 @@ sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.2 h1:jpcvIRr3GLoUo sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.2/go.mod h1:Ve9uj1L+deCXFrPOk1LpFXqTg7LCFzFso6PA48q/XZw= sigs.k8s.io/controller-runtime v0.21.0 h1:CYfjpEuicjUecRk+KAeyYh+ouUBn4llGyDYytIGcJS8= sigs.k8s.io/controller-runtime v0.21.0/go.mod h1:OSg14+F65eWqIu4DceX7k/+QRAbTTvxeQSNSOQpukWM= -sigs.k8s.io/controller-runtime/tools/setup-envtest v0.0.0-20250617162058-15c5d6129278 h1:Ncvj7v4WmmR8IuUQM7jlnRyG8FFuHTn0r0NbLOcCae4= -sigs.k8s.io/controller-runtime/tools/setup-envtest v0.0.0-20250617162058-15c5d6129278/go.mod h1:zCcqn1oG9844T8/vZSYcnqOyoEmTHro4bliTJI6j4OY= +sigs.k8s.io/controller-runtime/tools/setup-envtest v0.0.0-20250708091927-252af6420feb h1:cAM3D5ULjLKzf9vaP1IRCiisMAmIClnGrVLwvJHwTvw= +sigs.k8s.io/controller-runtime/tools/setup-envtest v0.0.0-20250708091927-252af6420feb/go.mod h1:j/Ij4VcmWyLTCr9L2j6cL331xnM0Dw+vOZZR1BuOXXg= sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 h1:/Rv+M11QRah1itp8VhT6HoVx1Ray9eB4DBr+K+/sCJ8= sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3/go.mod h1:18nIHnGi6636UCz6m8i4DhaJ65T6EruyzmoQqI2BVDo= sigs.k8s.io/randfill v0.0.0-20250304075658-069ef1bbf016/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY= diff --git a/e2e/util_test.go b/e2e/util_test.go index 329c7f3..4b055db 100644 --- a/e2e/util_test.go +++ b/e2e/util_test.go @@ -4,12 +4,22 @@ package e2e import ( "context" + "fmt" "io" + "os" + goruntime "runtime" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/spicedb/pkg/tuple" + "github.com/go-logr/logr" . "github.com/onsi/gomega" "github.com/samber/lo" + "github.com/spf13/afero" + "sigs.k8s.io/controller-runtime/tools/setup-envtest/env" + "sigs.k8s.io/controller-runtime/tools/setup-envtest/remote" + "sigs.k8s.io/controller-runtime/tools/setup-envtest/store" + "sigs.k8s.io/controller-runtime/tools/setup-envtest/versions" + "sigs.k8s.io/controller-runtime/tools/setup-envtest/workflows" ) // GetAllTuples collects all tuples matching the filter from SpiceDB @@ -44,9 +54,49 @@ func WriteTuples(ctx context.Context, rels []*v1.Relationship) { Relationship: rel, }) } - + _, err := proxySrv.PermissionClient().WriteRelationships(ctx, &v1.WriteRelationshipsRequest{ Updates: updates, }) Expect(err).To(Succeed()) } + +// setupEnvtest sets up the Kubernetes test binaries using setup-envtest +func setupEnvtest(log logr.Logger) string { + e := &env.Env{ + Log: log, + Client: &remote.HTTPClient{ + Log: log, + IndexURL: remote.DefaultIndexURL, + }, + Version: versions.Spec{ + Selector: versions.TildeSelector{}, + CheckLatest: false, + }, + VerifySum: true, + ForceDownload: false, + Platform: versions.PlatformItem{ + Platform: versions.Platform{ + OS: goruntime.GOOS, + Arch: goruntime.GOARCH, + }, + }, + FS: afero.Afero{Fs: afero.NewOsFs()}, + Store: store.NewAt("../testbin"), + Out: os.Stdout, + } + + version, err := versions.FromExpr("~1.33.0") + if err != nil { + panic(fmt.Sprintf("failed to parse version: %v", err)) + } + e.Version = version + + workflows.Use{ + UseEnv: true, + PrintFormat: env.PrintOverview, + AssetsPath: "../testbin", + }.Do(e) + + return fmt.Sprintf("../testbin/k8s/%s-%s-%s", e.Version.AsConcrete(), e.Platform.OS, e.Platform.Arch) +} diff --git a/go.mod b/go.mod index a9648cf..0600b18 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module github.com/authzed/spicedb-kubeapi-proxy go 1.24.4 +tool github.com/ecordell/optgen + require ( github.com/authzed/authzed-go v1.4.0 github.com/authzed/grpcutil v0.0.0-20240123194739-2ea1e3d2d98b @@ -78,6 +80,7 @@ require ( github.com/coreos/go-semver v0.3.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/dalzilio/rudd v1.1.1-0.20230806153452-9e08a6ea8170 // indirect + github.com/dave/jennifer v1.6.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/distribution/reference v0.6.0 // indirect github.com/dlmiddlecote/sqlstats v1.0.2 // indirect @@ -86,6 +89,7 @@ require ( github.com/envoyproxy/go-control-plane/envoy v1.32.4 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/fatih/color v1.18.0 // indirect + github.com/fatih/structtag v1.2.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect diff --git a/go.sum b/go.sum index cb9d5b6..25e615e 100644 --- a/go.sum +++ b/go.sum @@ -752,6 +752,8 @@ github.com/cschleiden/go-workflows v0.16.3-0.20230928210702-d72004e1fdf2 h1:N2oz github.com/cschleiden/go-workflows v0.16.3-0.20230928210702-d72004e1fdf2/go.mod h1:a2TcOFW/byjgukUjo2DAD/Cuqdj/ISgh/PB39r1bdH8= github.com/dalzilio/rudd v1.1.1-0.20230806153452-9e08a6ea8170 h1:bHEN1z3EOO/IXHTQ8ZcmGoW4gTJt+mSrH2Sd458uo0E= github.com/dalzilio/rudd v1.1.1-0.20230806153452-9e08a6ea8170/go.mod h1:IxPC4Bdi3WqUwyGBMgLrWWGx67aRtUAZmOZrkIr7qaM= +github.com/dave/jennifer v1.6.1 h1:T4T/67t6RAA5AIV6+NP8Uk/BIsXgDoqEowgycdQQLuk= +github.com/dave/jennifer v1.6.1/go.mod h1:nXbxhEmQfOZhWml3D1cDK5M1FLnMSozpbFN/m3RmGZc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -811,6 +813,8 @@ github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fatih/set v0.2.1 h1:nn2CaJyknWE/6txyUDGwysr3G5QC6xWB/PtVjPBbeaA= github.com/fatih/set v0.2.1/go.mod h1:+RKtMCH+favT2+3YecHGxcc0b4KyVWA1QWWJUs4E0CI= +github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= +github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= diff --git a/pkg/authz/authz.go b/pkg/authz/authz.go index 50da4bb..57000b6 100644 --- a/pkg/authz/authz.go +++ b/pkg/authz/authz.go @@ -21,6 +21,11 @@ func WithAuthorization(handler, failed http.Handler, restMapper meta.RESTMapper, ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // in embedded mode, we need to manually set the RequestURI + if req.RequestURI == "" && req.URL != nil { + req.RequestURI = req.URL.RequestURI() + } + input, err := inputExtractor.ExtractFromHttp(req) if err != nil { handleError(w, failed, req, err) diff --git a/pkg/authz/distributedtx/workflow_test.go b/pkg/authz/distributedtx/workflow_test.go index aaf7ea4..086f312 100644 --- a/pkg/authz/distributedtx/workflow_test.go +++ b/pkg/authz/distributedtx/workflow_test.go @@ -31,7 +31,7 @@ func TestWorkflow(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - srv, err := spicedb.NewServer(ctx, "") + srv, err := spicedb.NewServer(ctx, "", nil) require.NoError(t, err) go func() { require.NoError(t, srv.Run(ctx)) diff --git a/pkg/inmemory/README.md b/pkg/inmemory/README.md new file mode 100644 index 0000000..1053728 --- /dev/null +++ b/pkg/inmemory/README.md @@ -0,0 +1,122 @@ +# In-Memory HTTP Transport + +_note: written as if it will be split into its own package_ + +A high-performance, zero-network-overhead HTTP transport implementation that bypasses the network layer entirely by calling handlers directly in-process. + +## Overview + +The `inmemory` package provides an `http.RoundTripper` implementation that directly invokes HTTP handlers in memory during the RoundTrip call, eliminating all network serialization, parsing, and connection overhead. +This is ideal for embedded http services or testing and development environments. + +## Quick Start + +```go +package main + +import ( + "fmt" + "io" + "net/http" + + "github.com/authzed/spicedb-kubeapi-proxy/pkg/inmemory" +) + +func main() { + // Create your HTTP handler + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"message": "Hello, World!"}`)) + }) + + // Create an HTTP client with in-memory transport + client := inmemory.NewClient(handler) + + // Make requests - no network involved! + resp, err := client.Get("http://api.example.com/hello") + if err != nil { + panic(err) + } + defer resp.Body.Close() + + // Headers and status are available immediately + fmt.Printf("Status: %d\n", resp.StatusCode) + fmt.Printf("Content-Type: %s\n", resp.Header.Get("Content-Type")) + + // Read the response body + io.Copy(io.Discard, resp.Body) +} +``` + +## API Reference + +### `New(handler http.Handler) *Transport` + +Creates a new in-memory transport that will invoke the provided handler directly during RoundTrip execution. + +```go +transport := inmemory.New(myHandler) +client := &http.Client{Transport: transport} +``` + +### `NewClient(handler http.Handler) *http.Client` + +Convenience function that creates an HTTP client with an in-memory transport. + +```go +client := inmemory.NewClient(myHandler) +``` + +## Examples + +### Basic Usage + +```go +handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("Hello!")) +}) + +client := inmemory.NewClient(handler) +resp, _ := client.Get("http://example.com/") +body, _ := io.ReadAll(resp.Body) // Response contains "Hello!" +``` + +### With Request Bodies + +```go +handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + w.Write([]byte(fmt.Sprintf("Echo: %s", body))) +}) + +client := inmemory.NewClient(handler) +resp, _ := client.Post("http://example.com/echo", "text/plain", + strings.NewReader("test data")) +body, _ := io.ReadAll(resp.Body) // Response contains "Echo: test data" +``` + +### Complex Handler + +```go +mux := http.NewServeMux() +mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("healthy")) +}) +mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`[{"id": 1, "name": "Alice"}]`)) +}) + +client := inmemory.NewClient(mux) + +// Both endpoints work normally +healthResp, _ := client.Get("http://api.com/health") +usersResp, _ := client.Get("http://api.com/api/users") + +// Handlers executed immediately, read responses +io.ReadAll(healthResp.Body) +io.ReadAll(usersResp.Body) +``` diff --git a/pkg/inmemory/transport.go b/pkg/inmemory/transport.go new file mode 100644 index 0000000..a26c6f1 --- /dev/null +++ b/pkg/inmemory/transport.go @@ -0,0 +1,137 @@ +// Package inmemory provides an in-memory HTTP transport implementation that +// bypasses the network layer entirely by calling handlers directly in-process. +// +// This is useful for embedding HTTP servers directly into applications, +// testing, and high-performance scenarios where network overhead is undesirable. +package inmemory + +import ( + "bytes" + "fmt" + "io" + "net/http" +) + +// Transport implements http.RoundTripper by executing handlers directly +// in-process, eliminating network overhead and providing high performance +// for embedded scenarios. +type Transport struct { + handler http.Handler +} + +// New creates a new in-memory transport that will call the provided handler +// directly during RoundTrip execution. +func New(handler http.Handler) *Transport { + return &Transport{handler: handler} +} + +// RoundTrip implements http.RoundTripper by executing the handler immediately +// and creating a response with the captured body data. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + if t.handler == nil { + return nil, fmt.Errorf("no handler configured") + } + + // Create the response object + resp := &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + ContentLength: -1, + Request: req, + } + + // Create capturing response writer + capturer := &responseWriter{ + headers: resp.Header, + response: resp, + } + + // Execute the handler immediately + t.handler.ServeHTTP(capturer, req) + + // Set default status if not set + if resp.StatusCode == 0 { + resp.StatusCode = http.StatusOK + resp.Status = "200 OK" + } + + // Create body reader from captured data + var bodyReader io.Reader + if capturer.bodyData != nil { + bodyReader = bytes.NewReader(capturer.bodyData) + } else { + bodyReader = bytes.NewReader([]byte{}) + } + + resp.Body = &responseBody{reader: bodyReader} + return resp, nil +} + +// responseBody implements io.ReadCloser for response body data +type responseBody struct { + reader io.Reader + closed bool +} + +// Read implements io.Reader +func (b *responseBody) Read(p []byte) (n int, err error) { + if b.closed { + return 0, io.EOF + } + + if b.reader == nil { + return 0, io.EOF + } + + return b.reader.Read(p) +} + +// Close implements io.Closer +func (b *responseBody) Close() error { + b.closed = true + return nil +} + +// responseWriter implements http.ResponseWriter by capturing all writes +type responseWriter struct { + headers http.Header + response *http.Response + bodyData []byte +} + +// Header returns the response headers +func (w *responseWriter) Header() http.Header { + return w.headers +} + +// Write captures body data +func (w *responseWriter) Write(data []byte) (int, error) { + // Set default status on first write if not set + if w.response.StatusCode == 0 { + w.WriteHeader(http.StatusOK) + } + + // Append to body data + w.bodyData = append(w.bodyData, data...) + return len(data), nil +} + +// WriteHeader sets the response status +func (w *responseWriter) WriteHeader(statusCode int) { + // Only set status once (standard behavior) + if w.response.StatusCode != 0 { + return + } + + w.response.StatusCode = statusCode + w.response.Status = fmt.Sprintf("%d %s", statusCode, http.StatusText(statusCode)) +} + +// NewClient creates an http.Client that uses the in-memory transport +func NewClient(handler http.Handler) *http.Client { + return &http.Client{ + Transport: New(handler), + } +} diff --git a/pkg/inmemory/transport_test.go b/pkg/inmemory/transport_test.go new file mode 100644 index 0000000..f12a02e --- /dev/null +++ b/pkg/inmemory/transport_test.go @@ -0,0 +1,448 @@ +package inmemory + +import ( + "fmt" + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTransportBasic(t *testing.T) { + executed := false + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + executed = true + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"message": "hello world"}`)) + }) + + transport := New(handler) + require.NotNil(t, transport) + + req, err := http.NewRequest("GET", "http://example.com/test", nil) + require.NoError(t, err) + + // Handler should not be executed yet + require.False(t, executed) + + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, resp) + + // Handler should be executed immediately during RoundTrip + require.True(t, executed) + + defer resp.Body.Close() + + // Status and headers should be available immediately + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, "200 OK", resp.Status) + require.Equal(t, "application/json", resp.Header.Get("Content-Type")) + + // Read the body + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, `{"message": "hello world"}`, string(body)) +} + +func TestTransportImmediateExecution(t *testing.T) { + executed := false + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + executed = true + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Custom", "test-value") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"message": "hello"}`)) + }) + + transport := New(handler) + req, err := http.NewRequest("GET", "http://example.com/test", nil) + require.NoError(t, err) + + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + + // Handler should be executed immediately during RoundTrip + require.True(t, executed) + + // Headers and status should be available immediately + require.Equal(t, "application/json", resp.Header.Get("Content-Type")) + require.Equal(t, "test-value", resp.Header.Get("X-Custom")) + require.Equal(t, http.StatusCreated, resp.StatusCode) + require.Equal(t, "201 Created", resp.Status) + + // Multiple header access should work + require.Equal(t, []string{"test-value"}, resp.Header.Values("X-Custom")) + + resp.Body.Close() +} + +func TestTransportWithClient(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/health": + w.WriteHeader(http.StatusOK) + w.Write([]byte("healthy")) + case "/echo": + body, _ := io.ReadAll(r.Body) + w.Header().Set("Echo-Method", r.Method) + w.Write(body) + default: + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("not found")) + } + }) + + client := NewClient(handler) + require.NotNil(t, client) + + // Test GET request + resp, err := client.Get("http://example.com/health") + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "healthy", string(body)) + + // Test POST request + resp, err = client.Post("http://example.com/echo", "text/plain", strings.NewReader("hello")) + require.NoError(t, err) + defer resp.Body.Close() + + body, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, "POST", resp.Header.Get("Echo-Method")) + require.Equal(t, "hello", string(body)) +} + +func TestTransportStatusCodes(t *testing.T) { + testCases := []struct { + name string + handlerStatus int + expectedStatus int + }{ + {"OK", http.StatusOK, http.StatusOK}, + {"Created", http.StatusCreated, http.StatusCreated}, + {"BadRequest", http.StatusBadRequest, http.StatusBadRequest}, + {"Unauthorized", http.StatusUnauthorized, http.StatusUnauthorized}, + {"NotFound", http.StatusNotFound, http.StatusNotFound}, + {"InternalServerError", http.StatusInternalServerError, http.StatusInternalServerError}, + {"NoStatusSet", 0, http.StatusOK}, // Default when no status is set + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if tc.handlerStatus != 0 { + w.WriteHeader(tc.handlerStatus) + } + w.Write([]byte("response")) + }) + + transport := New(handler) + req, err := http.NewRequest("GET", "http://example.com/test", nil) + require.NoError(t, err) + + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Read body to trigger execution + io.ReadAll(resp.Body) + + require.Equal(t, tc.expectedStatus, resp.StatusCode) + }) + } +} + +func TestTransportHeaders(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Echo request headers as response headers + for k, v := range r.Header { + w.Header()[fmt.Sprintf("Echo-%s", k)] = v + } + w.Header().Set("Custom-Header", "custom-value") + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("ok")) + }) + + transport := New(handler) + req, err := http.NewRequest("GET", "http://example.com/test", nil) + require.NoError(t, err) + + req.Header.Set("Authorization", "Bearer token123") + req.Header.Set("X-Custom", "test-value") + + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Read body to trigger execution + io.ReadAll(resp.Body) + + require.Equal(t, http.StatusAccepted, resp.StatusCode) + require.Equal(t, "custom-value", resp.Header.Get("Custom-Header")) + require.Equal(t, "Bearer token123", resp.Header.Get("Echo-Authorization")) + require.Equal(t, "test-value", resp.Header.Get("Echo-X-Custom")) +} + +func TestTransportLargeResponse(t *testing.T) { + // Test with a large response + largeData := strings.Repeat("abcdefghij", 10000) // 100KB + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(largeData)) + }) + + transport := New(handler) + req, err := http.NewRequest("GET", "http://example.com/large", nil) + require.NoError(t, err) + + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, largeData, string(body)) +} + +func TestTransportMultipleReads(t *testing.T) { + executionCount := 0 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + executionCount++ + w.WriteHeader(http.StatusOK) + w.Write([]byte("test response")) + }) + + transport := New(handler) + req, err := http.NewRequest("GET", "http://example.com/test", nil) + require.NoError(t, err) + + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Handler should be executed once during RoundTrip + require.Equal(t, 1, executionCount) + + // Read in chunks to test multiple reads + buf := make([]byte, 4) + + // First read + n, err := resp.Body.Read(buf) + require.NoError(t, err) + require.Equal(t, 4, n) + require.Equal(t, "test", string(buf)) + require.Equal(t, 1, executionCount) // Still 1 + + // Second read + n, err = resp.Body.Read(buf) + require.NoError(t, err) + require.Equal(t, 4, n) + require.Equal(t, " res", string(buf)) + require.Equal(t, 1, executionCount) // Still 1 + + // Read the rest + remaining, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "ponse", string(remaining)) + require.Equal(t, 1, executionCount) // Still 1 +} + +func TestTransportRequestBody(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Set("Echo-Content-Length", fmt.Sprintf("%d", len(body))) + w.Header().Set("Echo-Content-Type", r.Header.Get("Content-Type")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf("Received: %s", string(body)))) + }) + + transport := New(handler) + + testBody := `{"test": "data", "number": 42}` + req, err := http.NewRequest("POST", "http://example.com/echo", strings.NewReader(testBody)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Read body to trigger execution + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, "application/json", resp.Header.Get("Echo-Content-Type")) + require.Equal(t, fmt.Sprintf("%d", len(testBody)), resp.Header.Get("Echo-Content-Length")) + require.Equal(t, fmt.Sprintf("Received: %s", testBody), string(body)) +} + +func TestTransportNilHandler(t *testing.T) { + transport := New(nil) + req, err := http.NewRequest("GET", "http://example.com/test", nil) + require.NoError(t, err) + + resp, err := transport.RoundTrip(req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "no handler configured") +} + +func TestTransportCloseWithoutRead(t *testing.T) { + executed := false + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + executed = true + w.WriteHeader(http.StatusOK) + w.Write([]byte("test")) + }) + + transport := New(handler) + req, err := http.NewRequest("GET", "http://example.com/test", nil) + require.NoError(t, err) + + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + + // Handler should be executed during RoundTrip + require.True(t, executed) + + // Close without reading + err = resp.Body.Close() + require.NoError(t, err) + + // Try to read after close - should get EOF + buf := make([]byte, 10) + n, err := resp.Body.Read(buf) + require.Equal(t, 0, n) + require.Equal(t, io.EOF, err) +} + +func TestTransportMultipleHeaders(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Add multiple values for the same header + w.Header().Add("Set-Cookie", "cookie1=value1") + w.Header().Add("Set-Cookie", "cookie2=value2") + w.Header().Add("X-Custom", "value1") + w.Header().Add("X-Custom", "value2") + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + }) + + transport := New(handler) + req, err := http.NewRequest("GET", "http://example.com/test", nil) + require.NoError(t, err) + + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Read body to trigger execution + io.ReadAll(resp.Body) + + require.Equal(t, http.StatusOK, resp.StatusCode) + + cookies := resp.Header.Values("Set-Cookie") + require.Len(t, cookies, 2) + require.Contains(t, cookies, "cookie1=value1") + require.Contains(t, cookies, "cookie2=value2") + + customs := resp.Header.Values("X-Custom") + require.Len(t, customs, 2) + require.Contains(t, customs, "value1") + require.Contains(t, customs, "value2") +} + +// Benchmark tests +func BenchmarkTransport(b *testing.B) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "ok"}`)) + }) + + transport := New(handler) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + req, err := http.NewRequest("GET", "http://example.com/test", nil) + if err != nil { + b.Fatal(err) + } + + resp, err := transport.RoundTrip(req) + if err != nil { + b.Fatal(err) + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + }) +} + +func BenchmarkTransportWithBody(b *testing.B) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + w.WriteHeader(http.StatusOK) + w.Write(body) + }) + + transport := New(handler) + testData := strings.Repeat("test data ", 100) // ~900 bytes + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + req, err := http.NewRequest("POST", "http://example.com/echo", strings.NewReader(testData)) + if err != nil { + b.Fatal(err) + } + + resp, err := transport.RoundTrip(req) + if err != nil { + b.Fatal(err) + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + }) +} + +func BenchmarkTransportLargeResponse(b *testing.B) { + // 1MB response to test memory efficiency + largeData := strings.Repeat("abcdefghij", 100000) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(largeData)) + }) + + transport := New(handler) + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + req, _ := http.NewRequest("GET", "http://example.com/large", nil) + resp, _ := transport.RoundTrip(req) + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } +} diff --git a/pkg/proxy/authn.go b/pkg/proxy/authn.go index 9f68066..500175a 100644 --- a/pkg/proxy/authn.go +++ b/pkg/proxy/authn.go @@ -4,15 +4,36 @@ import ( "context" "fmt" "net/http" + "strings" "github.com/spf13/pflag" "k8s.io/apiserver/pkg/authentication/authenticator" + "k8s.io/apiserver/pkg/authentication/user" genericapiserver "k8s.io/apiserver/pkg/server" kubeoptions "k8s.io/kubernetes/pkg/kubeapiserver/options" ) +// EmbeddedAuthentication configures authentication for embedded mode +type EmbeddedAuthentication struct { + Enabled bool + UsernameHeaders []string + GroupHeaders []string + ExtraHeaderPrefixes []string +} + +// NewEmbeddedAuthentication creates a new embedded authentication configuration with defaults +func NewEmbeddedAuthentication() *EmbeddedAuthentication { + return &EmbeddedAuthentication{ + Enabled: false, + UsernameHeaders: []string{"X-Remote-User"}, + GroupHeaders: []string{"X-Remote-Group"}, + ExtraHeaderPrefixes: []string{"X-Remote-Extra-"}, + } +} + type Authentication struct { BuiltInOptions *kubeoptions.BuiltInAuthenticationOptions + Embedded EmbeddedAuthentication } func NewAuthentication() *Authentication { @@ -24,6 +45,7 @@ func NewAuthentication() *Authentication { // WithServiceAccounts(). WithTokenFile(). WithRequestHeader(), + Embedded: *NewEmbeddedAuthentication(), } // TODO: ServiceAccounts // auth.BuiltInOptions.ServiceAccounts.Issuers = []string{"https://spicedb-kubeapi-proxy.default.svc"} @@ -31,7 +53,7 @@ func NewAuthentication() *Authentication { } func (c *Authentication) AdditionalAuthEnabled() bool { - return c.tokenAuthEnabled() || c.serviceAccountAuthEnabled() || c.oidcAuthEnabled() + return c.tokenAuthEnabled() || c.serviceAccountAuthEnabled() || c.oidcAuthEnabled() || c.Embedded.Enabled } func (c *Authentication) oidcAuthEnabled() bool { @@ -47,6 +69,56 @@ func (c *Authentication) serviceAccountAuthEnabled() bool { } func (c *Authentication) ApplyTo(ctx context.Context, authenticationInfo *genericapiserver.AuthenticationInfo, servingInfo *genericapiserver.SecureServingInfo) error { + if c.Embedded.Enabled { + // For embedded mode, use dedicated embedded authentication configuration + usernameHeaders := c.Embedded.UsernameHeaders + groupHeaders := c.Embedded.GroupHeaders + extraHeaderPrefixes := c.Embedded.ExtraHeaderPrefixes + + authenticationInfo.Authenticator = authenticator.RequestFunc(func(req *http.Request) (*authenticator.Response, bool, error) { + // Try username headers in order + var username string + for _, header := range usernameHeaders { + if value := req.Header.Get(header); value != "" { + username = value + break + } + } + if username == "" { + return nil, false, nil + } + + // Collect groups from all group headers + var groups []string + for _, header := range groupHeaders { + groups = append(groups, req.Header.Values(header)...) + } + + // Collect extra fields + extra := make(map[string][]string) + for key, values := range req.Header { + for _, prefix := range extraHeaderPrefixes { + if strings.HasPrefix(key, prefix) { + extraKey := strings.TrimPrefix(key, prefix) + // Convert to lowercase as per Kubernetes convention + extraKey = strings.ToLower(extraKey) + extra[extraKey] = values + break + } + } + } + + return &authenticator.Response{ + User: &user.DefaultInfo{ + Name: username, + Groups: groups, + Extra: extra, + }, + }, true, nil + }) + return nil + } + authenticatorConfig, err := c.BuiltInOptions.ToAuthenticationConfig() if err != nil { return err diff --git a/pkg/proxy/authn_test.go b/pkg/proxy/authn_test.go index 6219dba..9a71bed 100644 --- a/pkg/proxy/authn_test.go +++ b/pkg/proxy/authn_test.go @@ -95,7 +95,7 @@ func runProxyRequest(t testing.TB, ctx context.Context, headers map[string][]str AllowedNames: []string{"service"}, } - opts := NewOptions() + opts := NewOptions(WithEmbeddedSpiceDBEndpoint) opts.RestConfigFunc = func() (*rest.Config, http.RoundTripper, error) { ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -164,7 +164,6 @@ func runProxyRequest(t testing.TB, ctx context.Context, headers map[string][]str return rc, transport, nil } - opts.SpiceDBOptions.SpiceDBEndpoint = EmbeddedSpiceDBEndpoint opts.SecureServing.ServerCert.CertKey = certStore.servingCertKey opts.SecureServing.BindAddress = net.ParseIP("127.0.0.1") opts.SecureServing.BindPort = port diff --git a/pkg/proxy/embedded.go b/pkg/proxy/embedded.go new file mode 100644 index 0000000..ff0689c --- /dev/null +++ b/pkg/proxy/embedded.go @@ -0,0 +1,13 @@ +package proxy + +import ( + "k8s.io/client-go/rest" +) + +// EmbeddedRestConfig is the standard REST config for embedded mode. +// This config uses a special "http://embedded" host URL that signals +// to the embedded client transport that requests should be handled +// in-memory rather than over the network. +var EmbeddedRestConfig = &rest.Config{ + Host: "http://embedded", +} diff --git a/pkg/proxy/embedded_test.go b/pkg/proxy/embedded_test.go new file mode 100644 index 0000000..117c048 --- /dev/null +++ b/pkg/proxy/embedded_test.go @@ -0,0 +1,473 @@ +package proxy + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apiserver/pkg/endpoints/request" + utilfeature "k8s.io/apiserver/pkg/util/feature" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + logsv1 "k8s.io/component-base/logs/api/v1" + + "github.com/authzed/spicedb-kubeapi-proxy/pkg/rules" +) + +func TestEmbeddedMode(t *testing.T) { + defer require.NoError(t, logsv1.ResetForTest(utilfeature.DefaultFeatureGate)) + ctx := t.Context() + + opts := createEmbeddedTestOptions(t) + completedConfig, err := opts.Complete(ctx) + require.NoError(t, err) + + proxySrv, err := NewServer(ctx, completedConfig) + require.NoError(t, err) + + t.Run("basic embedded client", func(t *testing.T) { + // Get embedded client + client := proxySrv.GetEmbeddedClient() + require.NotNil(t, client, "embedded client should not be nil") + + // Test basic request (health endpoint doesn't require auth) + req, err := http.NewRequestWithContext(ctx, "GET", "http://embedded/healthz", nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Should get a response (may be 404 but not connection error) + require.NotEqual(t, 0, resp.StatusCode, "should get a status code") + + // Read body + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + }) + + t.Run("kubernetes client integration", func(t *testing.T) { + embeddedClient := proxySrv.GetEmbeddedClient() + require.NotNil(t, embeddedClient) + + // Create Kubernetes client using embedded transport + k8sClient := createKubernetesClient(t, embeddedClient, "admin-user", []string{"admin"}) + + // Make a simple API call through the client + // This tests that the kubernetes client-go library works with our embedded transport + _, err = k8sClient.CoreV1().Namespaces().List(ctx, metav1.ListOptions{Limit: 1}) + // We expect this to fail since we don't have a real API server + // but it should go through our proxy without connection errors + require.Error(t, err) + // The error should not be a connection refused error + require.NotContains(t, err.Error(), "connection refused") + }) + +} + +func TestEmbeddedModeCustomHeaders(t *testing.T) { + defer require.NoError(t, logsv1.ResetForTest(utilfeature.DefaultFeatureGate)) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Create a proxy with custom header names + opts := createEmbeddedTestOptions(t) + // Override with custom header names + opts.Authentication.Embedded.UsernameHeaders = []string{"Custom-User"} + opts.Authentication.Embedded.GroupHeaders = []string{"Custom-Groups"} + opts.Authentication.Embedded.ExtraHeaderPrefixes = []string{"Custom-Extra-"} + + completedConfig, err := opts.Complete(ctx) + require.NoError(t, err) + + proxySrv, err := NewServer(ctx, completedConfig) + require.NoError(t, err) + + client := proxySrv.GetEmbeddedClient() + require.NotNil(t, client) + + // Test request with custom headers + req, err := http.NewRequestWithContext(ctx, "GET", "http://embedded/healthz", nil) + require.NoError(t, err) + + // Use custom header names + req.Header.Set("Custom-User", "test-user") + req.Header.Set("Custom-Groups", "test-group") + req.Header.Set("Custom-Extra-Department", "engineering") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Should get a response + require.NotEqual(t, 0, resp.StatusCode, "should get a status code") + + // Read body + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) +} + +func TestEmbeddedModeAuthenticationConfiguration(t *testing.T) { + defer require.NoError(t, logsv1.ResetForTest(utilfeature.DefaultFeatureGate)) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Create one proxy with multiple header configuration for all tests + opts := createEmbeddedTestOptions(t) + // Configure multiple headers of each type + opts.Authentication.Embedded.UsernameHeaders = []string{"Primary-User", "Secondary-User", "X-Remote-User"} + opts.Authentication.Embedded.GroupHeaders = []string{"X-Remote-Group", "X-User-Groups"} + opts.Authentication.Embedded.ExtraHeaderPrefixes = []string{"X-Remote-Extra-", "X-User-Attr-"} + + completedConfig, err := opts.Complete(ctx) + require.NoError(t, err) + + proxySrv, err := NewServer(ctx, completedConfig) + require.NoError(t, err) + + client := proxySrv.GetEmbeddedClient() + require.NotNil(t, client) + + t.Run("username header priority", func(t *testing.T) { + // Test with secondary header (Primary-User missing) + req, err := http.NewRequestWithContext(ctx, "GET", "http://embedded/healthz", nil) + require.NoError(t, err) + + req.Header.Set("Secondary-User", "alice") + req.Header.Set("X-Remote-Group", "admin") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.NotEqual(t, 0, resp.StatusCode, "should get a status code") + + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + }) + + t.Run("multiple group headers", func(t *testing.T) { + // Test with groups from both headers + req, err := http.NewRequestWithContext(ctx, "GET", "http://embedded/healthz", nil) + require.NoError(t, err) + + req.Header.Set("X-Remote-User", "bob") + req.Header.Add("X-Remote-Group", "developers") + req.Header.Add("X-Remote-Group", "reviewers") + req.Header.Add("X-User-Groups", "admin") + req.Header.Add("X-User-Groups", "security") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.NotEqual(t, 0, resp.StatusCode, "should get a status code") + + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + }) + + t.Run("multiple extra header prefixes", func(t *testing.T) { + // Test with extra attributes from both prefixes + req, err := http.NewRequestWithContext(ctx, "GET", "http://embedded/healthz", nil) + require.NoError(t, err) + + req.Header.Set("X-Remote-User", "charlie") + req.Header.Set("X-Remote-Group", "engineers") + req.Header.Set("X-Remote-Extra-Department", "platform") + req.Header.Set("X-Remote-Extra-Team", "infrastructure") + req.Header.Set("X-User-Attr-Location", "remote") + req.Header.Set("X-User-Attr-Timezone", "UTC") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.NotEqual(t, 0, resp.StatusCode, "should get a status code") + + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + }) +} + +func TestEmbeddedModeDefaults(t *testing.T) { + defer require.NoError(t, logsv1.ResetForTest(utilfeature.DefaultFeatureGate)) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Create embedded proxy with no explicit header configuration to test defaults + opts := NewOptions(WithEmbeddedProxy, WithEmbeddedSpiceDBEndpoint) + opts.Authentication.Embedded.Enabled = true + + // Configure mock upstream server + opts.RestConfigFunc = func() (*rest.Config, http.RoundTripper, error) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"kind": "Status", "status": "Success"}`)) + })) + t.Cleanup(mockServer.Close) + + return &rest.Config{ + Host: mockServer.URL, + TLSClientConfig: rest.TLSClientConfig{ + Insecure: true, + }, + }, nil, nil + } + + // Use empty rules for testing - allow all requests + opts.Matcher = rules.MatcherFunc(func(match *request.RequestInfo) []*rules.RunnableRule { + return []*rules.RunnableRule{{ + Checks: []*rules.RelExpr{}, + }} + }) + + completedConfig, err := opts.Complete(ctx) + require.NoError(t, err) + + proxySrv, err := NewServer(ctx, completedConfig) + require.NoError(t, err) + + client := proxySrv.GetEmbeddedClient() + require.NotNil(t, client) + + // Test with default headers (X-Remote-User, X-Remote-Group, X-Remote-Extra-) + req, err := http.NewRequestWithContext(ctx, "GET", "http://embedded/healthz", nil) + require.NoError(t, err) + + req.Header.Set("X-Remote-User", "default-user") + req.Header.Set("X-Remote-Group", "default-group") + req.Header.Set("X-Remote-Extra-Department", "engineering") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Should get a response + require.NotEqual(t, 0, resp.StatusCode, "should get a status code") + + // Read body + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) +} + +func TestEmbeddedClientFunctionalOptions(t *testing.T) { + defer require.NoError(t, logsv1.ResetForTest(utilfeature.DefaultFeatureGate)) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Create one proxy server for all subtests to avoid logging config issues + opts := createEmbeddedTestOptions(t) + completedConfig, err := opts.Complete(ctx) + require.NoError(t, err) + + proxySrv, err := NewServer(ctx, completedConfig) + require.NoError(t, err) + + t.Run("basic client without options", func(t *testing.T) { + client := proxySrv.GetEmbeddedClient() + require.NotNil(t, client) + + // Basic client should not add any authentication headers automatically + req, err := http.NewRequestWithContext(ctx, "GET", "http://embedded/healthz", nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.NotEqual(t, 0, resp.StatusCode) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + }) + + t.Run("transport adds headers correctly", func(t *testing.T) { + // Test the authHeaderTransport directly + baseTransport := &testTransport{} + transport := &authHeaderTransport{ + base: baseTransport, + username: "test-user", + groups: []string{"developers", "admin"}, + extra: map[string]string{"department": "engineering", "location": "remote"}, + usernameHeaders: []string{"X-Remote-User"}, + groupHeaders: []string{"X-Remote-Group"}, + extraHeaderPrefixes: []string{"X-Remote-Extra-"}, + } + + req, err := http.NewRequest("GET", "http://example.com/test", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + // Check that baseTransport received the request with added headers + capturedReq := baseTransport.lastRequest + require.NotNil(t, capturedReq) + + // Check username header + require.Equal(t, "test-user", capturedReq.Header.Get("X-Remote-User")) + + // Check group headers + groups := capturedReq.Header.Values("X-Remote-Group") + require.Contains(t, groups, "developers") + require.Contains(t, groups, "admin") + require.Len(t, groups, 2) + + // Check extra headers + require.Equal(t, "engineering", capturedReq.Header.Get("X-Remote-Extra-department")) + require.Equal(t, "remote", capturedReq.Header.Get("X-Remote-Extra-location")) + }) + + t.Run("functional options create correct transport", func(t *testing.T) { + client := proxySrv.GetEmbeddedClient( + WithUser("alice"), + WithGroups("security", "reviewers"), + WithExtra("team", "platform"), + ) + require.NotNil(t, client) + + // Check that the transport is wrapped + transport, ok := client.Transport.(*authHeaderTransport) + require.True(t, ok, "transport should be wrapped with authHeaderTransport") + + // Check configuration + require.Equal(t, "alice", transport.username) + require.Equal(t, []string{"security", "reviewers"}, transport.groups) + require.Equal(t, "platform", transport.extra["team"]) + require.Equal(t, []string{"X-Remote-User"}, transport.usernameHeaders) + require.Equal(t, []string{"X-Remote-Group"}, transport.groupHeaders) + require.Equal(t, []string{"X-Remote-Extra-"}, transport.extraHeaderPrefixes) + }) + +} + +func TestEmbeddedClientCustomHeaderConfig(t *testing.T) { + defer require.NoError(t, logsv1.ResetForTest(utilfeature.DefaultFeatureGate)) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Create proxy with custom header names + customOpts := createEmbeddedTestOptions(t) + customOpts.Authentication.Embedded.UsernameHeaders = []string{"Custom-User"} + customOpts.Authentication.Embedded.GroupHeaders = []string{"Custom-Groups"} + customOpts.Authentication.Embedded.ExtraHeaderPrefixes = []string{"Custom-Extra-"} + + customCompletedConfig, err := customOpts.Complete(ctx) + require.NoError(t, err) + + customProxySrv, err := NewServer(ctx, customCompletedConfig) + require.NoError(t, err) + + client := customProxySrv.GetEmbeddedClient( + WithUser("charlie"), + WithGroups("security"), + WithExtra("team", "infrastructure"), + ) + require.NotNil(t, client) + + // Check that custom header names are used + transport, ok := client.Transport.(*authHeaderTransport) + require.True(t, ok) + require.Equal(t, []string{"Custom-User"}, transport.usernameHeaders) + require.Equal(t, []string{"Custom-Groups"}, transport.groupHeaders) + require.Equal(t, []string{"Custom-Extra-"}, transport.extraHeaderPrefixes) +} + +// testTransport is a simple transport that captures the last request +type testTransport struct { + lastRequest *http.Request +} + +func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { + t.lastRequest = req + return &http.Response{ + StatusCode: 200, + Body: http.NoBody, + Header: make(http.Header), + }, nil +} + +// createEmbeddedTestOptions creates minimal options for embedded testing +func createEmbeddedTestOptions(t *testing.T) *Options { + t.Helper() + + opts := NewOptions(WithEmbeddedProxy, WithEmbeddedSpiceDBEndpoint) + opts.Authentication.Embedded.Enabled = true + + // Configure mock upstream server + opts.RestConfigFunc = func() (*rest.Config, http.RoundTripper, error) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"kind": "Status", "status": "Success"}`)) + })) + t.Cleanup(mockServer.Close) + + return &rest.Config{ + Host: mockServer.URL, + TLSClientConfig: rest.TLSClientConfig{ + Insecure: true, + }, + }, nil, nil + } + + // Use empty rules for testing - allow all requests + opts.Matcher = rules.MatcherFunc(func(match *request.RequestInfo) []*rules.RunnableRule { + return []*rules.RunnableRule{{ + Checks: []*rules.RelExpr{}, + }} + }) + + return opts +} + +// createKubernetesClient creates a kubernetes client using the embedded transport +func createKubernetesClient(t *testing.T, embeddedClient *http.Client, username string, groups []string) *kubernetes.Clientset { + t.Helper() + + // Create rest config that uses the embedded transport + restConfig := rest.CopyConfig(EmbeddedRestConfig) + restConfig.Transport = embeddedClient.Transport + + // Wrap transport to add authentication headers + restConfig.Transport = &headerAddingTransport{ + base: embeddedClient.Transport, + username: username, + groups: groups, + } + + clientset, err := kubernetes.NewForConfig(restConfig) + require.NoError(t, err) + + return clientset +} + +// headerAddingTransport wraps an http.RoundTripper to add authentication headers +type headerAddingTransport struct { + base http.RoundTripper + username string + groups []string +} + +func (h *headerAddingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone request to avoid modifying original + newReq := req.Clone(req.Context()) + + // Add authentication headers + newReq.Header.Set("X-Remote-User", h.username) + for _, group := range h.groups { + newReq.Header.Add("X-Remote-Group", group) + } + + return h.base.RoundTrip(newReq) +} diff --git a/pkg/proxy/options.go b/pkg/proxy/options.go index 400a77b..ac3c44e 100644 --- a/pkg/proxy/options.go +++ b/pkg/proxy/options.go @@ -41,8 +41,6 @@ const ( defaultDialerTimeout = 5 * time.Second ) -//go:generate go run github.com/ecordell/optgen -output zz_spicedb_options.go . SpiceDBOptions - type Options struct { SecureServing apiserveroptions.SecureServingOptionsWithLoopback `debugmap:"hidden"` Authentication Authentication `debugmap:"hidden"` @@ -61,6 +59,9 @@ type Options struct { CertDir string `debugmap:"visible"` + // Embedded mode configuration + EmbeddedMode bool `debugmap:"visible"` + AuthenticationInfo genericapiserver.AuthenticationInfo `debugmap:"hidden"` ServingInfo *genericapiserver.SecureServingInfo `debugmap:"hidden"` AdditionalAuthEnabled bool `debugmap:"visible"` @@ -74,9 +75,11 @@ type Options struct { PermissionsClient v1.PermissionsServiceClient `debugmap:"hidden"` } +//go:generate go run github.com/ecordell/optgen -output zz_spicedb_options.go . SpiceDBOptions type SpiceDBOptions struct { SpiceDBEndpoint string `debugmap:"visible"` EmbeddedSpiceDB server.RunnableServer `debugmap:"hidden"` + BootstrapContent map[string][]byte `debugmap:"hidden"` Insecure bool `debugmap:"sensitive"` SkipVerifyCA bool `debugmap:"visible"` SecureSpiceDBTokensBySpace string `debugmap:"sensitive"` @@ -103,7 +106,9 @@ func (so *SpiceDBOptions) AddFlags(fs *pflag.FlagSet) { const tlsCertificatePairName = "tls" -func NewOptions() *Options { +type setOpt func(*Options) + +func NewOptions(opts ...setOpt) *Options { o := &Options{ SecureServing: *apiserveroptions.NewSecureServingOptions().WithLoopback(), Authentication: *NewAuthentication(), @@ -113,9 +118,57 @@ func NewOptions() *Options { o.Logs.Verbosity = logsv1.VerbosityLevel(3) o.SecureServing.BindPort = 443 o.SecureServing.ServerCert.PairName = tlsCertificatePairName + + for _, opt := range opts { + opt(o) + } + return o } +// WithEmbeddedProxy configures the proxy to run in embedded mode. +// In embedded mode, the proxy runs as an HTTP server without TLS termination, +// suitable for use behind a load balancer or ingress controller. +func WithEmbeddedProxy(o *Options) { + o.EmbeddedMode = true +} + +// WithEmbeddedSpiceDBEndpoint configures the proxy to use an embedded SpiceDB instance. +// This creates an in-memory SpiceDB instance that runs within the proxy process. +// Use this for development, testing, or single-node deployments. +func WithEmbeddedSpiceDBEndpoint(o *Options) { + o.SpiceDBOptions.SpiceDBEndpoint = EmbeddedSpiceDBEndpoint +} + +// WithEmbeddedSpiceDBBootstrap configures the proxy to use an embedded SpiceDB instance +// with custom bootstrap content. This allows you to provide your own schema and initial +// relationships directly as a byte slice instead of using a file. +// +// The bootstrap content should be provided as a map with filename as key and YAML content as value: +// +// bootstrapContent := map[string][]byte{ +// "bootstrap.yaml": []byte(`schema: |- +// definition user {} +// definition namespace { +// relation creator: user +// permission view = creator +// } +// definition lock { +// relation workflow: workflow +// } +// definition workflow {} +// relationships: | +// `), +// } +// +// Use this for testing or when you want to programmatically define your SpiceDB schema. +func WithEmbeddedSpiceDBBootstrap(bootstrapContent map[string][]byte) func(*Options) { + return func(o *Options) { + o.SpiceDBOptions.SpiceDBEndpoint = EmbeddedSpiceDBEndpoint + o.SpiceDBOptions.BootstrapContent = bootstrapContent + } +} + func (o *Options) FromRESTConfig(restConfig *rest.Config) *Options { o.OverrideUpstream = false o.UseInClusterConfig = false @@ -216,33 +269,40 @@ func (o *Options) Complete(ctx context.Context) (*CompletedConfig, error) { o.InputExtractor = rules.ResolveInputExtractorFunc(rules.NewResolveInputFromHttp) } - if !filepath.IsAbs(o.SecureServing.ServerCert.CertDirectory) { - o.SecureServing.ServerCert.CertDirectory = filepath.Join(o.CertDir, o.SecureServing.ServerCert.CertDirectory) - } + // Set embedded mode in authentication + o.Authentication.Embedded.Enabled = o.EmbeddedMode - if err := o.SecureServing.MaybeDefaultWithSelfSignedCerts("localhost", []string{"kubernetes.default.svc", "kubernetes.default", "kubernetes"}, nil); err != nil { - return nil, err - } + if !o.EmbeddedMode { + if !filepath.IsAbs(o.SecureServing.ServerCert.CertDirectory) { + o.SecureServing.ServerCert.CertDirectory = filepath.Join(o.CertDir, o.SecureServing.ServerCert.CertDirectory) + } - var loopbackClientConfig *rest.Config - if err := o.SecureServing.ApplyTo(&o.ServingInfo, &loopbackClientConfig); err != nil { - return nil, err + if err := o.SecureServing.MaybeDefaultWithSelfSignedCerts("localhost", []string{"kubernetes.default.svc", "kubernetes.default", "kubernetes"}, nil); err != nil { + return nil, err + } + + var loopbackClientConfig *rest.Config + if err := o.SecureServing.ApplyTo(&o.ServingInfo, &loopbackClientConfig); err != nil { + return nil, err + } } + if err := o.Authentication.ApplyTo(ctx, &o.AuthenticationInfo, o.ServingInfo); err != nil { return nil, err } o.AdditionalAuthEnabled = o.Authentication.AdditionalAuthEnabled() - spicedbURl, err := url.Parse(o.SpiceDBOptions.SpiceDBEndpoint) + spicedbURL, err := url.Parse(o.SpiceDBOptions.SpiceDBEndpoint) if err != nil { return nil, fmt.Errorf("unable to parse SpiceDB endpoint URL: %w", err) } var conn *grpc.ClientConn - if spicedbURl.Scheme == "embedded" { - klog.FromContext(ctx).WithValues("spicedb-endpoint", spicedbURl).Info("using embedded SpiceDB") - o.SpiceDBOptions.EmbeddedSpiceDB, err = spicedb.NewServer(ctx, spicedbURl.Path) + if spicedbURL.Scheme == "embedded" { + klog.FromContext(ctx).WithValues("spicedb-endpoint", spicedbURL).Info("using embedded SpiceDB") + + o.SpiceDBOptions.EmbeddedSpiceDB, err = spicedb.NewServer(ctx, spicedbURL.Path, o.SpiceDBOptions.BootstrapContent) if err != nil { return nil, fmt.Errorf("unable to stand up embedded SpiceDB: %w", err) } @@ -344,7 +404,7 @@ func (o *Options) configFromPath() (*clientcmdapi.Config, error) { func (o *Options) Validate() []error { var errs []error - if len(o.BackendKubeconfigPath) == 0 && !o.UseInClusterConfig { + if len(o.BackendKubeconfigPath) == 0 && !o.UseInClusterConfig && o.RestConfigFunc == nil { errs = append(errs, fmt.Errorf("either --backend-kubeconfig or --use-in-cluster-config must be specified")) } @@ -352,9 +412,9 @@ func (o *Options) Validate() []error { errs = append(errs, fmt.Errorf("--rule-config is required")) } - errs = append(errs, o.SecureServing.Validate()...) - errs = append(errs, o.Authentication.Validate()...) - + if !o.EmbeddedMode { + errs = append(errs, o.SecureServing.Validate()...) + } return errs } diff --git a/pkg/proxy/options_test.go b/pkg/proxy/options_test.go index 0710a13..dab3bf9 100644 --- a/pkg/proxy/options_test.go +++ b/pkg/proxy/options_test.go @@ -25,8 +25,7 @@ import ( func TestKubeConfig(t *testing.T) { defer require.NoError(t, logsv1.ResetForTest(utilfeature.DefaultFeatureGate)) - opts := optionsForTesting(t) - opts.SpiceDBOptions.SpiceDBEndpoint = EmbeddedSpiceDBEndpoint + opts := optionsForTesting(t, WithEmbeddedSpiceDBEndpoint) require.Empty(t, opts.Validate()) c, err := opts.Complete(context.Background()) @@ -46,8 +45,7 @@ func TestKubeConfig(t *testing.T) { func TestInClusterConfig(t *testing.T) { defer require.NoError(t, logsv1.ResetForTest(utilfeature.DefaultFeatureGate)) - opts := optionsForTesting(t) - opts.SpiceDBOptions.SpiceDBEndpoint = EmbeddedSpiceDBEndpoint + opts := optionsForTesting(t, WithEmbeddedSpiceDBEndpoint) opts.BackendKubeconfigPath = "" opts.UseInClusterConfig = true require.Empty(t, opts.Validate()) @@ -62,8 +60,7 @@ func TestInClusterConfig(t *testing.T) { } func TestEmbeddedSpiceDB(t *testing.T) { - opts := optionsForTesting(t) - opts.SpiceDBOptions.SpiceDBEndpoint = EmbeddedSpiceDBEndpoint + opts := optionsForTesting(t, WithEmbeddedSpiceDBEndpoint) require.Empty(t, opts.Validate()) c, err := opts.Complete(context.Background()) @@ -116,8 +113,7 @@ func TestRemoteSpiceDBCerts(t *testing.T) { } func TestRuleConfig(t *testing.T) { - opts := optionsForTesting(t) - opts.SpiceDBOptions.SpiceDBEndpoint = EmbeddedSpiceDBEndpoint + opts := optionsForTesting(t, WithEmbeddedSpiceDBEndpoint) require.Empty(t, opts.Validate()) c, err := opts.Complete(context.Background()) @@ -139,7 +135,7 @@ func TestRuleConfig(t *testing.T) { errConfigBytes := []byte(` apiVersion: authzed.com/v1alpha1 kind: ProxyRule -lock: Pessimistic +lock: Pessimistic match: - apiVersion: authzed.com/v1alpha1 resource: spicedbclusters @@ -151,8 +147,7 @@ prefilter: `) errConfigFile := path.Join(t.TempDir(), "rulesbad.yaml") require.NoError(t, os.WriteFile(errConfigFile, errConfigBytes, 0o600)) - opts = optionsForTesting(t) - opts.SpiceDBOptions.SpiceDBEndpoint = EmbeddedSpiceDBEndpoint + opts = optionsForTesting(t, WithEmbeddedSpiceDBEndpoint) opts.RuleConfigFile = errConfigFile require.Empty(t, opts.Validate()) @@ -160,17 +155,105 @@ prefilter: require.ErrorContains(t, err, "expected") } -func optionsForTesting(t *testing.T) *Options { +func optionsForTesting(t *testing.T, opts ...setOpt) *Options { t.Helper() require.NoError(t, logsv1.ResetForTest(utilfeature.DefaultFeatureGate)) - opts := NewOptions() - opts.SecureServing.BindPort = getFreePort(t, "127.0.0.1") - opts.SecureServing.BindAddress = net.ParseIP("127.0.0.1") - opts.BackendKubeconfigPath = kubeConfigForTest(t) - opts.RuleConfigFile = ruleConfigForTest(t) + options := NewOptions(opts...) + options.SecureServing.BindPort = getFreePort(t, "127.0.0.1") + options.SecureServing.BindAddress = net.ParseIP("127.0.0.1") + options.BackendKubeconfigPath = kubeConfigForTest(t) + options.RuleConfigFile = ruleConfigForTest(t) + require.Empty(t, options.Validate()) + return options +} + +func TestWithEmbeddedProxy(t *testing.T) { + opts := NewOptions(WithEmbeddedProxy) + require.True(t, opts.EmbeddedMode) +} + +func TestWithEmbeddedSpiceDBEndpoint(t *testing.T) { + opts := NewOptions(WithEmbeddedSpiceDBEndpoint) + require.Equal(t, EmbeddedSpiceDBEndpoint, opts.SpiceDBOptions.SpiceDBEndpoint) +} + +func TestWithBothEmbeddedOptions(t *testing.T) { + opts := NewOptions(WithEmbeddedProxy, WithEmbeddedSpiceDBEndpoint) + require.True(t, opts.EmbeddedMode) + require.Equal(t, EmbeddedSpiceDBEndpoint, opts.SpiceDBOptions.SpiceDBEndpoint) +} + +func TestWithEmbeddedProxyOnly(t *testing.T) { + opts := NewOptions(WithEmbeddedProxy) + require.True(t, opts.EmbeddedMode) + require.NotEqual(t, EmbeddedSpiceDBEndpoint, opts.SpiceDBOptions.SpiceDBEndpoint) +} + +func TestWithEmbeddedSpiceDBEndpointOnly(t *testing.T) { + opts := NewOptions(WithEmbeddedSpiceDBEndpoint) + require.False(t, opts.EmbeddedMode) + require.Equal(t, EmbeddedSpiceDBEndpoint, opts.SpiceDBOptions.SpiceDBEndpoint) +} + +func TestWithEmbeddedSpiceDBBootstrap(t *testing.T) { + bootstrapContent := map[string][]byte{ + "bootstrap.yaml": []byte(`schema: |- + definition user {} + definition namespace { + relation creator: user + permission view = creator + } +relationships: | +`), + } + + opts := NewOptions(WithEmbeddedSpiceDBBootstrap(bootstrapContent)) + require.False(t, opts.EmbeddedMode) + require.Equal(t, EmbeddedSpiceDBEndpoint, opts.SpiceDBOptions.SpiceDBEndpoint) + require.Equal(t, bootstrapContent, opts.SpiceDBOptions.BootstrapContent) +} + +func TestWithEmbeddedSpiceDBBootstrapIntegration(t *testing.T) { + defer require.NoError(t, logsv1.ResetForTest(utilfeature.DefaultFeatureGate)) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Create custom bootstrap content + bootstrapContent := map[string][]byte{ + "bootstrap.yaml": []byte(`schema: |- + definition user {} + definition namespace { + relation creator: user + permission view = creator + } + definition lock { + relation workflow: workflow + } + definition workflow {} +relationships: | +`), + } + + opts := optionsForTesting(t, WithEmbeddedSpiceDBBootstrap(bootstrapContent)) require.Empty(t, opts.Validate()) - return opts + + c, err := opts.Complete(ctx) + require.NoError(t, err) + require.NotNil(t, c) + + // Verify that the embedded SpiceDB is created with custom bootstrap + require.NotNil(t, opts.SpiceDBOptions.EmbeddedSpiceDB) + require.NotNil(t, opts.PermissionsClient) + require.NotNil(t, opts.WatchClient) + require.Equal(t, bootstrapContent, opts.SpiceDBOptions.BootstrapContent) +} + +func TestNewOptionsWithoutEmbedded(t *testing.T) { + opts := NewOptions() + require.False(t, opts.EmbeddedMode) + require.NotEqual(t, EmbeddedSpiceDBEndpoint, opts.SpiceDBOptions.SpiceDBEndpoint) } func getFreePort(t *testing.T, listenAddr string) int { @@ -238,10 +321,10 @@ func ruleConfigForTest(t *testing.T) string { configBytes := []byte(` apiVersion: authzed.com/v1alpha1 kind: ProxyRule -lock: Pessimistic +lock: Pessimistic match: - apiVersion: authzed.com/v1alpha1 - resource: spicedbclusters + resource: spicedbclusters verbs: ["list"] prefilter: - fromObjectIDNameExpr: "{{request.name}}" diff --git a/pkg/proxy/server.go b/pkg/proxy/server.go index fcacc63..3dcd9de 100644 --- a/pkg/proxy/server.go +++ b/pkg/proxy/server.go @@ -33,6 +33,7 @@ import ( "github.com/authzed/spicedb-kubeapi-proxy/pkg/authz" "github.com/authzed/spicedb-kubeapi-proxy/pkg/authz/distributedtx" + "github.com/authzed/spicedb-kubeapi-proxy/pkg/inmemory" "github.com/authzed/spicedb-kubeapi-proxy/pkg/rules" ) @@ -76,23 +77,26 @@ func NewServer(ctx context.Context, c *CompletedConfig) (*Server, error) { clusterHost = restConfig.Host klog.FromContext(ctx).WithValues("host", clusterHost).Error(err, "created upstream client") + // Embedded mode setup is done after handler is ready + mux := http.NewServeMux() mux.Handle("/readyz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = w.Write([]byte("OK")) w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) })) mux.Handle("/livez", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = w.Write([]byte("OK")) w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) })) clusterProxy := &httputil.ReverseProxy{ ErrorLog: nil, // TODO FlushInterval: -1, Director: func(req *http.Request) { - req.URL.Host = strings.TrimPrefix(clusterHost, "https://") + host := strings.TrimPrefix(clusterHost, "https://") + req.URL.Host = strings.TrimSuffix(host, "/") req.URL.Scheme = "https" }, ModifyResponse: func(response *http.Response) error { @@ -174,14 +178,18 @@ func (s *Server) Run(ctx context.Context) error { return s.WorkflowWorker.Start(ctx) }) - g.Go(func() error { - done, _, err := s.opts.ServingInfo.Serve(s.Handler, time.Second*60, ctx.Done()) - if err != nil { - return err - } - <-done - return nil - }) + if !s.opts.EmbeddedMode { + // For regular mode, use TLS serving + g.Go(func() error { + done, _, err := s.opts.ServingInfo.Serve(s.Handler, time.Second*60, ctx.Done()) + if err != nil { + return err + } + <-done + return nil + }) + } + // For embedded mode, connections are handled on-demand via GetEmbeddedClient() if err := g.Wait(); err != nil { ctx, cancel = context.WithTimeout(context.Background(), 1*time.Minute) @@ -258,3 +266,126 @@ func computeDiscoverCacheDir(parentDir, host string) string { // overlyCautiousIllegalFileCharacters matches characters that *might* not be supported. Windows is really restrictive, so this is really restrictive var overlyCautiousIllegalFileCharacters = regexp.MustCompile(`[^(\w/.)]`) + +// EmbeddedClientOption configures the embedded client +type EmbeddedClientOption func(*embeddedClientConfig) + +// embeddedClientConfig holds configuration for embedded client +type embeddedClientConfig struct { + username string + groups []string + extra map[string]string +} + +// WithUser sets the username for the embedded client +func WithUser(username string) EmbeddedClientOption { + return func(config *embeddedClientConfig) { + config.username = username + } +} + +// WithGroups sets the groups for the embedded client +func WithGroups(groups ...string) EmbeddedClientOption { + return func(config *embeddedClientConfig) { + config.groups = groups + } +} + +// WithExtra sets extra attributes for the embedded client +func WithExtra(key, value string) EmbeddedClientOption { + return func(config *embeddedClientConfig) { + if config.extra == nil { + config.extra = make(map[string]string) + } + config.extra[key] = value + } +} + +// GetEmbeddedClient returns an HTTP client that connects directly to the handler +func (s *Server) GetEmbeddedClient(opts ...EmbeddedClientOption) *http.Client { + if !s.opts.EmbeddedMode || s.Handler == nil { + return nil + } + + // Create base client + client := inmemory.NewClient(s.Handler) + + // If no options provided, return basic client + if len(opts) == 0 { + return client + } + + // Apply options to configuration + config := &embeddedClientConfig{} + for _, opt := range opts { + opt(config) + } + + // Get configured header names from embedded authentication + usernameHeaders := s.opts.Authentication.Embedded.UsernameHeaders + if len(usernameHeaders) == 0 { + usernameHeaders = []string{"X-Remote-User"} + } + + groupHeaders := s.opts.Authentication.Embedded.GroupHeaders + if len(groupHeaders) == 0 { + groupHeaders = []string{"X-Remote-Group"} + } + + extraHeaderPrefixes := s.opts.Authentication.Embedded.ExtraHeaderPrefixes + if len(extraHeaderPrefixes) == 0 { + extraHeaderPrefixes = []string{"X-Remote-Extra-"} + } + + // Wrap the transport to add authentication headers automatically + client.Transport = &authHeaderTransport{ + base: client.Transport, + username: config.username, + groups: config.groups, + extra: config.extra, + usernameHeaders: usernameHeaders, + groupHeaders: groupHeaders, + extraHeaderPrefixes: extraHeaderPrefixes, + } + + return client +} + +// authHeaderTransport automatically adds authentication headers based on configuration +type authHeaderTransport struct { + base http.RoundTripper + username string + groups []string + extra map[string]string + usernameHeaders []string + groupHeaders []string + extraHeaderPrefixes []string +} + +func (t *authHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone request to avoid modifying original + newReq := req.Clone(req.Context()) + + // Add username header (use first configured header) + if t.username != "" && len(t.usernameHeaders) > 0 { + newReq.Header.Set(t.usernameHeaders[0], t.username) + } + + // Add group headers (use first configured header for all groups) + if len(t.groups) > 0 && len(t.groupHeaders) > 0 { + for _, group := range t.groups { + newReq.Header.Add(t.groupHeaders[0], group) + } + } + + // Add extra headers (use first configured prefix) + if len(t.extra) > 0 && len(t.extraHeaderPrefixes) > 0 { + prefix := t.extraHeaderPrefixes[0] + for key, value := range t.extra { + headerName := prefix + strings.ToLower(key) + newReq.Header.Set(headerName, value) + } + } + + return t.base.RoundTrip(newReq) +} diff --git a/pkg/spicedb/spicedb.go b/pkg/spicedb/spicedb.go index bec33ab..309fcf9 100644 --- a/pkg/spicedb/spicedb.go +++ b/pkg/spicedb/spicedb.go @@ -14,9 +14,11 @@ import ( //go:embed bootstrap.yaml var bootstrap []byte -func NewServer(ctx context.Context, bootstrapFilePath string) (server.RunnableServer, error) { +func NewServer(ctx context.Context, bootstrapFilePath string, bootstrapContent map[string][]byte) (server.RunnableServer, error) { bootstrapOption := datastore.SetBootstrapFileContents(map[string][]byte{"schema": bootstrap}) - if bootstrapFilePath != "" { + if len(bootstrapContent) > 0 { + bootstrapOption = datastore.SetBootstrapFileContents(bootstrapContent) + } else if len(bootstrapFilePath) > 0 { bootstrapOption = datastore.SetBootstrapFiles([]string{bootstrapFilePath}) } return server.NewConfigWithOptionsAndDefaults(server.WithGRPCServer(util.GRPCServerConfig{