From 2c656b4cabf165b5b1e7b46f36cb923d5fa581bd Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 7 May 2025 10:58:53 +0200 Subject: [PATCH] tlv: fix SizeFunc signature There are cases where the SizeFunc call-back can produce an error. We should catch and handle this error else it leads to code that could silently produce errors. --- tlv/fuzz_test.go | 12 ++++++------ tlv/primitive.go | 17 ++++++++++------- tlv/record.go | 12 ++++++------ tlv/record_test.go | 14 ++++++++------ tlv/stream.go | 9 ++++++++- tlv/tlv_test.go | 12 ++++++------ 6 files changed, 44 insertions(+), 32 deletions(-) diff --git a/tlv/fuzz_test.go b/tlv/fuzz_test.go index 0a38aa66eab..824ea113d5d 100644 --- a/tlv/fuzz_test.go +++ b/tlv/fuzz_test.go @@ -234,14 +234,14 @@ func FuzzStream(f *testing.F) { boolean bool ) - sizeTU16 := func() uint64 { - return SizeTUint16(tu16) + sizeTU16 := func() (uint64, error) { + return SizeTUint16(tu16), nil } - sizeTU32 := func() uint64 { - return SizeTUint32(tu32) + sizeTU32 := func() (uint64, error) { + return SizeTUint32(tu32), nil } - sizeTU64 := func() uint64 { - return SizeTUint64(tu64) + sizeTU64 := func() (uint64, error) { + return SizeTUint64(tu64), nil } // We deliberately set each record's type number to its index in diff --git a/tlv/primitive.go b/tlv/primitive.go index d241d273d0f..626552d5cf7 100644 --- a/tlv/primitive.go +++ b/tlv/primitive.go @@ -379,15 +379,18 @@ func DBigSize(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { func SizeBigSize(val interface{}) SizeFunc { var size uint64 - if i, ok := val.(*uint32); ok { - size = VarIntSize(uint64(*i)) - } - - if i, ok := val.(*uint64); ok { + switch i := val.(type) { + case *uint32: size = VarIntSize(uint64(*i)) + case *uint64: + size = VarIntSize(*i) + default: + return func() (uint64, error) { + return 0, fmt.Errorf("invalid type %T for BigSize", val) + } } - return func() uint64 { - return size + return func() (uint64, error) { + return size, nil } } diff --git a/tlv/record.go b/tlv/record.go index 7813ee771e3..86c8fcb9e36 100644 --- a/tlv/record.go +++ b/tlv/record.go @@ -40,16 +40,16 @@ func DNOP(io.Reader, interface{}, *[8]byte, uint64) error { return nil } // SizeFunc is a function that can compute the length of a given field. Since // the size of the underlying field can change, this allows the size of the // field to be evaluated at the time of encoding. -type SizeFunc func() uint64 +type SizeFunc func() (uint64, error) // SizeVarBytes returns a SizeFunc that can compute the length of a byte slice. func SizeVarBytes(e *[]byte) SizeFunc { - return func() uint64 { - return uint64(len(*e)) + return func() (uint64, error) { + return uint64(len(*e)), nil } } -// RecorderProducer is an interface for objects that can produce a Record object +// RecordProducer is an interface for objects that can produce a Record object // capable of encoding and/or decoding the RecordProducer as a Record. type RecordProducer interface { // Record returns a Record that can be used to encode or decode the @@ -78,9 +78,9 @@ func (f *Record) Record() Record { // Size returns the size of the Record's value. If no static size is known, the // dynamic size will be evaluated. -func (f *Record) Size() uint64 { +func (f *Record) Size() (uint64, error) { if f.sizeFunc == nil { - return f.staticSize + return f.staticSize, nil } return f.sizeFunc() diff --git a/tlv/record_test.go b/tlv/record_test.go index d1f14501503..37ce064d1cd 100644 --- a/tlv/record_test.go +++ b/tlv/record_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/davecgh/go-spew/spew" + "github.com/stretchr/testify/require" ) // TestSortRecords tests that SortRecords is able to properly sort records in @@ -135,12 +136,13 @@ func TestRecordMapTransformation(t *testing.T) { i, tlvBytes, b.Bytes()) } - if unmappedRecords[i].Size() != testCase.records[0].Size() { - t.Fatalf("#%v: wrong size: expected %v, "+ - "got %v", i, - unmappedRecords[i].Size(), - testCase.records[i].Size()) - } + unmappedSize, err := unmappedRecords[i].Size() + require.NoError(t, err) + + testCaseSize, err := testCase.records[0].Size() + require.NoError(t, err) + + require.Equal(t, unmappedSize, testCaseSize) } } } diff --git a/tlv/stream.go b/tlv/stream.go index 70fade22d79..5ef3e0637e0 100644 --- a/tlv/stream.go +++ b/tlv/stream.go @@ -3,6 +3,7 @@ package tlv import ( "bytes" "errors" + "fmt" "io" "math" ) @@ -94,8 +95,14 @@ func (s *Stream) Encode(w io.Writer) error { return err } + size, err := rec.Size() + if err != nil { + return fmt.Errorf("could not determine record size: %w", + err) + } + // Write the record's length as a varint. - err = WriteVarInt(w, rec.Size(), &s.buf) + err = WriteVarInt(w, size, &s.buf) if err != nil { return err } diff --git a/tlv/tlv_test.go b/tlv/tlv_test.go index 35e23cab72a..4b8d7499518 100644 --- a/tlv/tlv_test.go +++ b/tlv/tlv_test.go @@ -54,8 +54,8 @@ type N1 struct { stream *tlv.Stream } -func (n *N1) sizeAmt() uint64 { - return tlv.SizeTUint64(n.amt) +func (n *N1) sizeAmt() (uint64, error) { + return tlv.SizeTUint64(n.amt), nil } func NewN1() *N1 { @@ -89,12 +89,12 @@ type N2 struct { stream *tlv.Stream } -func (n *N2) sizeAmt() uint64 { - return tlv.SizeTUint64(n.amt) +func (n *N2) sizeAmt() (uint64, error) { + return tlv.SizeTUint64(n.amt), nil } -func (n *N2) sizeCltv() uint64 { - return tlv.SizeTUint32(n.cltvExpiry) +func (n *N2) sizeCltv() (uint64, error) { + return tlv.SizeTUint32(n.cltvExpiry), nil } func NewN2() *N2 {