Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tlv/fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,14 @@ func FuzzStream(f *testing.F) {
boolean bool
)

sizeTU16 := func() uint64 {
return SizeTUint16(tu16)
sizeTU16 := func() (uint64, error) {
return SizeTUint16(tu16), nil
}
sizeTU32 := func() uint64 {
return SizeTUint32(tu32)
sizeTU32 := func() (uint64, error) {
return SizeTUint32(tu32), nil
}
sizeTU64 := func() uint64 {
return SizeTUint64(tu64)
sizeTU64 := func() (uint64, error) {
return SizeTUint64(tu64), nil
}

// We deliberately set each record's type number to its index in
Expand Down
17 changes: 10 additions & 7 deletions tlv/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,15 +379,18 @@ func DBigSize(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
func SizeBigSize(val interface{}) SizeFunc {
var size uint64

if i, ok := val.(*uint32); ok {
size = VarIntSize(uint64(*i))
}

if i, ok := val.(*uint64); ok {
switch i := val.(type) {
case *uint32:
size = VarIntSize(uint64(*i))
case *uint64:
size = VarIntSize(*i)
default:
return func() (uint64, error) {
return 0, fmt.Errorf("invalid type %T for BigSize", val)
}
}

return func() uint64 {
return size
return func() (uint64, error) {
return size, nil
}
}
12 changes: 6 additions & 6 deletions tlv/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ func DNOP(io.Reader, interface{}, *[8]byte, uint64) error { return nil }
// SizeFunc is a function that can compute the length of a given field. Since
// the size of the underlying field can change, this allows the size of the
// field to be evaluated at the time of encoding.
type SizeFunc func() uint64
type SizeFunc func() (uint64, error)

// SizeVarBytes returns a SizeFunc that can compute the length of a byte slice.
func SizeVarBytes(e *[]byte) SizeFunc {
return func() uint64 {
return uint64(len(*e))
return func() (uint64, error) {
return uint64(len(*e)), nil
}
}

// RecorderProducer is an interface for objects that can produce a Record object
// RecordProducer is an interface for objects that can produce a Record object
// capable of encoding and/or decoding the RecordProducer as a Record.
type RecordProducer interface {
// Record returns a Record that can be used to encode or decode the
Expand Down Expand Up @@ -78,9 +78,9 @@ func (f *Record) Record() Record {

// Size returns the size of the Record's value. If no static size is known, the
// dynamic size will be evaluated.
func (f *Record) Size() uint64 {
func (f *Record) Size() (uint64, error) {
if f.sizeFunc == nil {
return f.staticSize
return f.staticSize, nil
}

return f.sizeFunc()
Expand Down
14 changes: 8 additions & 6 deletions tlv/record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"github.com/davecgh/go-spew/spew"
"github.com/stretchr/testify/require"
)

// TestSortRecords tests that SortRecords is able to properly sort records in
Expand Down Expand Up @@ -135,12 +136,13 @@ func TestRecordMapTransformation(t *testing.T) {
i, tlvBytes, b.Bytes())
}

if unmappedRecords[i].Size() != testCase.records[0].Size() {
t.Fatalf("#%v: wrong size: expected %v, "+
"got %v", i,
unmappedRecords[i].Size(),
testCase.records[i].Size())
}
unmappedSize, err := unmappedRecords[i].Size()
require.NoError(t, err)

testCaseSize, err := testCase.records[0].Size()
require.NoError(t, err)

require.Equal(t, unmappedSize, testCaseSize)
}
}
}
9 changes: 8 additions & 1 deletion tlv/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tlv
import (
"bytes"
"errors"
"fmt"
"io"
"math"
)
Expand Down Expand Up @@ -94,8 +95,14 @@ func (s *Stream) Encode(w io.Writer) error {
return err
}

size, err := rec.Size()
if err != nil {
return fmt.Errorf("could not determine record size: %w",
err)
}

// Write the record's length as a varint.
err = WriteVarInt(w, rec.Size(), &s.buf)
err = WriteVarInt(w, size, &s.buf)
if err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions tlv/tlv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ type N1 struct {
stream *tlv.Stream
}

func (n *N1) sizeAmt() uint64 {
return tlv.SizeTUint64(n.amt)
func (n *N1) sizeAmt() (uint64, error) {
return tlv.SizeTUint64(n.amt), nil
}

func NewN1() *N1 {
Expand Down Expand Up @@ -89,12 +89,12 @@ type N2 struct {
stream *tlv.Stream
}

func (n *N2) sizeAmt() uint64 {
return tlv.SizeTUint64(n.amt)
func (n *N2) sizeAmt() (uint64, error) {
return tlv.SizeTUint64(n.amt), nil
}

func (n *N2) sizeCltv() uint64 {
return tlv.SizeTUint32(n.cltvExpiry)
func (n *N2) sizeCltv() (uint64, error) {
return tlv.SizeTUint32(n.cltvExpiry), nil
}

func NewN2() *N2 {
Expand Down
Loading