Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ require (
github.com/quay/goval-parser v0.8.8
github.com/remind101/migrate v0.0.0-20170729031349-52c1edff7319
github.com/spdx/tools-golang v0.5.6
github.com/tetratelabs/wazero v1.11.0
github.com/ulikunitz/xz v0.5.15
go.opentelemetry.io/otel v1.39.0
go.opentelemetry.io/otel/trace v1.39.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tetratelabs/wazero v1.11.0 h1:+gKemEuKCTevU4d7ZTzlsvgd1uaToIDtlQlmNbwqYhA=
github.com/tetratelabs/wazero v1.11.0/go.mod h1:eV28rsN8Q+xwjogd7f4/Pp4xFxO7uOGbLcD/LzB1wiU=
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
Expand Down
212 changes: 212 additions & 0 deletions internal/matcher/wasm/host.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package wasm

import (
"context"
"fmt"
"reflect"
"slices"
"strings"
"sync"
"unsafe"

"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"

"github.com/quay/claircore"
)

// PtrMember is a helper to take a pointer to a Go struct, then return a
// pointer that's contained as a field.
func ptrMember(off uintptr) api.GoModuleFunc {
return func(ctx context.Context, mod api.Module, stack []uint64) {
// Take in *A, which has a *B at offset "off".
ref := unsafe.Pointer(api.DecodeExternref(stack[0])) // Shouldn't be nil.
ptrField := unsafe.Add(ref, off) // This pointer can't be nil.
ptr := *(*unsafe.Pointer)(ptrField) // Can be nil.
stack[0] = api.EncodeExternref(uintptr(ptr))
}
}

// PtrToMember is a helper to take a pointer to a Go struct, then return a
// pointer to a contained field.
func ptrToMember(off uintptr) api.GoModuleFunc {
return func(ctx context.Context, mod api.Module, stack []uint64) {
// Take in *A, which has a B at offset "off".
ref := unsafe.Pointer(api.DecodeExternref(stack[0])) // Shouldn't be nil.
ptr := unsafe.Add(ref, off) // This pointer can't be nil.
stack[0] = api.EncodeExternref(uintptr(ptr))
}
}

// StringMember is a helper to take a pointer to a Go struct, then return a
// copy of a string member to a caller-allocated buffer.
func stringMember(off uintptr) api.GoModuleFunc {
return func(ctx context.Context, mod api.Module, stack []uint64) {
// Unsure of another way to get this length information.
h := (*reflect.StringHeader)(unsafe.Add(unsafe.Pointer(api.DecodeExternref(stack[0])), off))
offset := api.DecodeU32(stack[1])
lim := int(api.DecodeU32(stack[2]))
s := unsafe.String((*byte)(unsafe.Pointer(h.Data)), h.Len)
sz := min(lim, len(s))
if sz == 0 {
stack[0] = api.EncodeI32(0)
return
}
s = s[:sz]
mem := mod.ExportedMemory("memory")
if mem.WriteString(offset, s) {
stack[0] = api.EncodeI32(int32(sz))
} else {
stack[0] = api.EncodeI32(0)
}
}
}

// StringerMember is a helper to take a pointer to a Go struct, then place the
// string representation of a member into a caller-allocated buffer.
func stringerMember(off uintptr) api.GoModuleFunc {
return func(ctx context.Context, mod api.Module, stack []uint64) {
iface := (any)(unsafe.Pointer(api.DecodeExternref(stack[0]) + off)).(fmt.Stringer)
offset := api.DecodeU32(stack[1])
lim := int(api.DecodeU32(stack[2]))
s := iface.String()
sz := min(lim, len(s))
if mod.ExportedMemory("memory").WriteString(offset, s[:sz]) {
stack[0] = api.EncodeI32(int32(sz))
} else {
stack[0] = api.EncodeI32(0)
}
}
}

// NotNil checks that the passed externref is not-nil.
//
// This is needed because externrefs are unobservable from within WASM; they
// can only be handed back to the host and not manipulated in any way.
func notNil(ctx context.Context, mod api.Module, stack []uint64) {
if api.DecodeExternref(stack[0]) != 0 {
stack[0] = api.EncodeI32(1)
} else {
stack[0] = api.EncodeI32(0)
}
}

type methodSpec struct {
Name string
Func api.GoModuleFunc
Params []api.ValueType
ParamNames []string
Results []api.ValueType
ResultNames []string
}

func (s *methodSpec) Build(b wazero.HostModuleBuilder) wazero.HostModuleBuilder {
return b.NewFunctionBuilder().
WithName(s.Name).
WithParameterNames(s.ParamNames...).
WithResultNames(s.ResultNames...).
WithGoModuleFunction(s.Func, s.Params, s.Results).
Export(s.Name)
}

func gettersFor[T any]() []methodSpec {
t := reflect.TypeFor[T]()
recv := strings.ToLower(t.Name())
out := make([]methodSpec, 0, t.NumField())

switch t {
// These types are passed-in and always valid.
case reflect.TypeFor[claircore.IndexRecord](),
reflect.TypeFor[claircore.Vulnerability]():
default:
out = append(out, methodSpec{
Name: fmt.Sprintf("%s_valid", recv),
Func: notNil,
Params: []api.ValueType{api.ValueTypeExternref},
Results: []api.ValueType{api.ValueTypeI32},
ParamNames: []string{recv + "Ref"},
ResultNames: []string{"ok"},
})
}
for i := 0; i < t.NumField(); i++ {
sf := t.Field(i)
if !sf.IsExported() {
continue
}
if sf.Name == "ID" { // Skip "id" fields.
continue
}

ft := sf.Type
tgt := strings.ToLower(sf.Name)
// Do some fixups:
switch tgt {
case "dist":
tgt = "distribution"
case "arch":
tgt = "architecture"
case "repo":
tgt = "repository"
}
mi := len(out)
out = append(out, methodSpec{})
m := &out[mi]
m.Name = fmt.Sprintf("%s_get_%s", recv, tgt)
switch ft.Kind() {
case reflect.Pointer:
m.Func = ptrMember(sf.Offset)
m.Params = []api.ValueType{api.ValueTypeExternref}
m.Results = []api.ValueType{api.ValueTypeExternref}
m.ParamNames = []string{recv + "Ref"}
m.ResultNames = []string{tgt + "Ref"}
case reflect.String:
m.Func = stringMember(sf.Offset)
m.Params = []api.ValueType{api.ValueTypeExternref, api.ValueTypeI32, api.ValueTypeI32}
m.Results = []api.ValueType{api.ValueTypeI32}
m.ParamNames = []string{recv + "Ref", "buf", "buf_len"}
m.ResultNames = []string{"len"}
case reflect.Struct:
switch {
case ft == reflect.TypeFor[claircore.Version]():
m.Func = ptrToMember(sf.Offset)
m.Params = []api.ValueType{api.ValueTypeExternref}
m.Results = []api.ValueType{api.ValueTypeExternref}
m.ParamNames = []string{recv + "Ref"}
m.ResultNames = []string{tgt + "Ref"}
case ft.Implements(reflect.TypeFor[fmt.Stringer]()):
m.Func = stringerMember(sf.Offset)
m.Params = []api.ValueType{api.ValueTypeExternref, api.ValueTypeI32, api.ValueTypeI32}
m.Results = []api.ValueType{api.ValueTypeI32}
m.ParamNames = []string{recv + "Ref", "buf", "buf_len"}
m.ResultNames = []string{"len"}
default:
out = out[:mi]
}
default:
out = out[:mi]
}
}

return slices.Clip(out)
}

var hostV1Interface = sync.OnceValue(func() []methodSpec {
return slices.Concat(
gettersFor[claircore.IndexRecord](),
gettersFor[claircore.Detector](),
gettersFor[claircore.Distribution](),
gettersFor[claircore.Package](),
gettersFor[claircore.Range](),
gettersFor[claircore.Repository](),
gettersFor[claircore.Version](),
gettersFor[claircore.Vulnerability](),
)
})

func buildHostV1Interface(rt wazero.Runtime) wazero.HostModuleBuilder {
b := rt.NewHostModuleBuilder("claircore_matcher_1")
for _, spec := range hostV1Interface() {
b = spec.Build(b)
}
return b
}
111 changes: 111 additions & 0 deletions internal/matcher/wasm/host_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package wasm

import (
"maps"
"os"
"slices"
"strings"
"sync"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"

"github.com/quay/claircore"
"github.com/quay/claircore/libvuln/driver"
)

func init() {
// Override the disk caching for tests.
cache = sync.OnceValue(wazero.NewCompilationCache)
}

func TestHostV1(t *testing.T) {
ctx := t.Context()
rConfig := runtimeConfig()
rt := wazero.NewRuntimeWithConfig(ctx, rConfig)
mod, err := buildHostV1Interface(rt).Compile(ctx)
if err != nil {
t.Fatal(err)
}
fns := mod.ExportedFunctions()
keys := slices.Collect(maps.Keys(fns))
slices.Sort(keys)
var b strings.Builder

writelist := func(ts []api.ValueType, ns []string) {
b.WriteByte('(')
for i := range ts {
if i != 0 {
b.WriteString(", ")
}
b.WriteString(ns[i])
b.WriteString(": ")
switch ts[i] {
case api.ValueTypeExternref:
b.WriteString("externref")
case api.ValueTypeI32:
b.WriteString("i32")
case api.ValueTypeI64:
b.WriteString("i64")
case api.ValueTypeF32:
b.WriteString("f32")
case api.ValueTypeF64:
b.WriteString("f64")
default:
b.WriteString("???")
}
}
b.WriteByte(')')
}
for _, k := range keys {
v := fns[k]
b.Reset()
b.WriteString(v.DebugName())
writelist(v.ParamTypes(), v.ParamNames())
b.WriteString(" -> ")
writelist(v.ResultTypes(), v.ResultNames())

t.Log(b.String())
}
}

func TestTrivial(t *testing.T) {
ctx := t.Context()
f, err := os.Open("testdata/trivial.wasm")
if err != nil {
t.Fatal(err)
}
defer f.Close()

m, err := NewMatcher(ctx, "trivial", f)
if err != nil {
t.Fatal(err)
}

t.Run("Query", func(t *testing.T) {
want := []driver.MatchConstraint{driver.PackageName, driver.HasFixedInVersion}
got := m.Query()
if !cmp.Equal(got, want) {
t.Error(cmp.Diff(got, want))
}
})

t.Log(`testing trvial matcher: "Filter() == true" when "len(IndexRecord.Package.Name) != 0"`)
r := &claircore.IndexRecord{
Package: &claircore.Package{Name: "pkg"},
}
ok := m.Filter(r)
t.Logf("package name %q: %v", r.Package.Name, ok)
if !ok {
t.Fail()
}

r.Package = new(claircore.Package)
ok = m.Filter(r)
t.Logf("package name %q: %v", r.Package.Name, ok)
if ok {
t.Fail()
}
}
Loading