/ tlv / primitive.go
primitive.go
  1  package tlv
  2  
  3  import (
  4  	"encoding/binary"
  5  	"errors"
  6  	"fmt"
  7  	"io"
  8  
  9  	"github.com/btcsuite/btcd/btcec/v2"
 10  )
 11  
 12  // ErrTypeForEncoding signals that an incorrect type was passed to an Encoder.
 13  type ErrTypeForEncoding struct {
 14  	val     interface{}
 15  	expType string
 16  }
 17  
 18  // NewTypeForEncodingErr creates a new ErrTypeForEncoding given the incorrect
 19  // val and the expected type.
 20  func NewTypeForEncodingErr(val interface{}, expType string) ErrTypeForEncoding {
 21  	return ErrTypeForEncoding{
 22  		val:     val,
 23  		expType: expType,
 24  	}
 25  }
 26  
 27  // Error returns a human-readable description of the type mismatch.
 28  func (e ErrTypeForEncoding) Error() string {
 29  	return fmt.Sprintf("ErrTypeForEncoding want (type: *%s), "+
 30  		"got (type: %T)", e.expType, e.val)
 31  }
 32  
 33  // ErrTypeForDecoding signals that an incorrect type was passed to a Decoder or
 34  // that the expected length of the encoding is different from that required by
 35  // the expected type.
 36  type ErrTypeForDecoding struct {
 37  	val       interface{}
 38  	expType   string
 39  	valLength uint64
 40  	expLength uint64
 41  }
 42  
 43  // NewTypeForDecodingErr creates a new ErrTypeForDecoding given the incorrect
 44  // val and expected type, or the mismatch in their expected lengths.
 45  func NewTypeForDecodingErr(val interface{}, expType string,
 46  	valLength, expLength uint64) ErrTypeForDecoding {
 47  
 48  	return ErrTypeForDecoding{
 49  		val:       val,
 50  		expType:   expType,
 51  		valLength: valLength,
 52  		expLength: expLength,
 53  	}
 54  }
 55  
 56  // Error returns a human-readable description of the type mismatch.
 57  func (e ErrTypeForDecoding) Error() string {
 58  	return fmt.Sprintf("ErrTypeForDecoding want (type: *%s, length: %v), "+
 59  		"got (type: %T, length: %v)", e.expType, e.expLength, e.val,
 60  		e.valLength)
 61  }
 62  
 63  var (
 64  	byteOrder = binary.BigEndian
 65  )
 66  
 67  // EUint8 is an Encoder for uint8 values. An error is returned if val is not a
 68  // *uint8.
 69  func EUint8(w io.Writer, val interface{}, buf *[8]byte) error {
 70  	if i, ok := val.(*uint8); ok {
 71  		return EUint8T(w, *i, buf)
 72  	}
 73  	return NewTypeForEncodingErr(val, "uint8")
 74  }
 75  
 76  // EUint8T encodes a uint8 val to the provided io.Writer. This method is exposed
 77  // so that encodings for custom uint8-like types can be created without
 78  // incurring an extra heap allocation.
 79  func EUint8T(w io.Writer, val uint8, buf *[8]byte) error {
 80  	buf[0] = val
 81  	_, err := w.Write(buf[:1])
 82  	return err
 83  }
 84  
 85  // EUint16 is an Encoder for uint16 values. An error is returned if val is not a
 86  // *uint16.
 87  func EUint16(w io.Writer, val interface{}, buf *[8]byte) error {
 88  	if i, ok := val.(*uint16); ok {
 89  		return EUint16T(w, *i, buf)
 90  	}
 91  	return NewTypeForEncodingErr(val, "uint16")
 92  }
 93  
 94  // EUint16T encodes a uint16 val to the provided io.Writer. This method is
 95  // exposed so that encodings for custom uint16-like types can be created without
 96  // incurring an extra heap allocation.
 97  func EUint16T(w io.Writer, val uint16, buf *[8]byte) error {
 98  	byteOrder.PutUint16(buf[:2], val)
 99  	_, err := w.Write(buf[:2])
100  	return err
101  }
102  
103  // EUint32 is an Encoder for uint32 values. An error is returned if val is not a
104  // *uint32.
105  func EUint32(w io.Writer, val interface{}, buf *[8]byte) error {
106  	if i, ok := val.(*uint32); ok {
107  		return EUint32T(w, *i, buf)
108  	}
109  	return NewTypeForEncodingErr(val, "uint32")
110  }
111  
112  // EUint32T encodes a uint32 val to the provided io.Writer. This method is
113  // exposed so that encodings for custom uint32-like types can be created without
114  // incurring an extra heap allocation.
115  func EUint32T(w io.Writer, val uint32, buf *[8]byte) error {
116  	byteOrder.PutUint32(buf[:4], val)
117  	_, err := w.Write(buf[:4])
118  	return err
119  }
120  
121  // EUint64 is an Encoder for uint64 values. An error is returned if val is not a
122  // *uint64.
123  func EUint64(w io.Writer, val interface{}, buf *[8]byte) error {
124  	if i, ok := val.(*uint64); ok {
125  		return EUint64T(w, *i, buf)
126  	}
127  	return NewTypeForEncodingErr(val, "uint64")
128  }
129  
130  // EUint64T encodes a uint64 val to the provided io.Writer. This method is
131  // exposed so that encodings for custom uint64-like types can be created without
132  // incurring an extra heap allocation.
133  func EUint64T(w io.Writer, val uint64, buf *[8]byte) error {
134  	byteOrder.PutUint64(buf[:], val)
135  	_, err := w.Write(buf[:])
136  	return err
137  }
138  
139  // EBool encodes a boolean. An error is returned if val is not a boolean.
140  func EBool(w io.Writer, val interface{}, buf *[8]byte) error {
141  	if i, ok := val.(*bool); ok {
142  		return EBoolT(w, *i, buf)
143  	}
144  	return NewTypeForEncodingErr(val, "bool")
145  }
146  
// EBoolT writes val to w as a single byte: 0x01 for true, 0x00 for false. The
// byte is staged in buf[0]. It takes the value directly rather than through an
// interface so that encoders for custom bool-like types can reuse it without
// an extra heap allocation.
func EBoolT(w io.Writer, val bool, buf *[8]byte) error {
	var encoded byte
	if val {
		encoded = 1
	}
	buf[0] = encoded

	_, err := w.Write(buf[:1])
	return err
}
159  
160  // DUint8 is a Decoder for uint8 values. An error is returned if val is not a
161  // *uint8.
162  func DUint8(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
163  	if i, ok := val.(*uint8); ok && l == 1 {
164  		if _, err := io.ReadFull(r, buf[:1]); err != nil {
165  			return err
166  		}
167  		*i = buf[0]
168  		return nil
169  	}
170  	return NewTypeForDecodingErr(val, "uint8", l, 1)
171  }
172  
173  // DUint16 is a Decoder for uint16 values. An error is returned if val is not a
174  // *uint16.
175  func DUint16(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
176  	if i, ok := val.(*uint16); ok && l == 2 {
177  		if _, err := io.ReadFull(r, buf[:2]); err != nil {
178  			return err
179  		}
180  		*i = byteOrder.Uint16(buf[:2])
181  		return nil
182  	}
183  	return NewTypeForDecodingErr(val, "uint16", l, 2)
184  }
185  
186  // DUint32 is a Decoder for uint32 values. An error is returned if val is not a
187  // *uint32.
188  func DUint32(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
189  	if i, ok := val.(*uint32); ok && l == 4 {
190  		if _, err := io.ReadFull(r, buf[:4]); err != nil {
191  			return err
192  		}
193  		*i = byteOrder.Uint32(buf[:4])
194  		return nil
195  	}
196  	return NewTypeForDecodingErr(val, "uint32", l, 4)
197  }
198  
199  // DUint64 is a Decoder for uint64 values. An error is returned if val is not a
200  // *uint64.
201  func DUint64(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
202  	if i, ok := val.(*uint64); ok && l == 8 {
203  		if _, err := io.ReadFull(r, buf[:]); err != nil {
204  			return err
205  		}
206  		*i = byteOrder.Uint64(buf[:])
207  		return nil
208  	}
209  	return NewTypeForDecodingErr(val, "uint64", l, 8)
210  }
211  
212  // DBool decodes a boolean. An error is returned if val is not a boolean.
213  func DBool(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
214  	if i, ok := val.(*bool); ok && l == 1 {
215  		if _, err := io.ReadFull(r, buf[:1]); err != nil {
216  			return err
217  		}
218  		if buf[0] != 0 && buf[0] != 1 {
219  			return errors.New("corrupted data")
220  		}
221  		*i = buf[0] != 0
222  		return nil
223  	}
224  	return NewTypeForDecodingErr(val, "bool", l, 1)
225  }
226  
227  // EBytes32 is an Encoder for 32-byte arrays. An error is returned if val is not
228  // a *[32]byte.
229  func EBytes32(w io.Writer, val interface{}, _ *[8]byte) error {
230  	if b, ok := val.(*[32]byte); ok {
231  		_, err := w.Write(b[:])
232  		return err
233  	}
234  	return NewTypeForEncodingErr(val, "[32]byte")
235  }
236  
237  // DBytes32 is a Decoder for 32-byte arrays. An error is returned if val is not
238  // a *[32]byte.
239  func DBytes32(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
240  	if b, ok := val.(*[32]byte); ok && l == 32 {
241  		_, err := io.ReadFull(r, b[:])
242  		return err
243  	}
244  	return NewTypeForDecodingErr(val, "[32]byte", l, 32)
245  }
246  
247  // EBytes33 is an Encoder for 33-byte arrays. An error is returned if val is not
248  // a *[33]byte.
249  func EBytes33(w io.Writer, val interface{}, _ *[8]byte) error {
250  	if b, ok := val.(*[33]byte); ok {
251  		_, err := w.Write(b[:])
252  		return err
253  	}
254  	return NewTypeForEncodingErr(val, "[33]byte")
255  }
256  
257  // DBytes33 is a Decoder for 33-byte arrays. An error is returned if val is not
258  // a *[33]byte.
259  func DBytes33(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
260  	if b, ok := val.(*[33]byte); ok && l == 33 {
261  		_, err := io.ReadFull(r, b[:])
262  		return err
263  	}
264  	return NewTypeForDecodingErr(val, "[33]byte", l, 33)
265  }
266  
267  // EBytes64 is an Encoder for 64-byte arrays. An error is returned if val is not
268  // a *[64]byte.
269  func EBytes64(w io.Writer, val interface{}, _ *[8]byte) error {
270  	if b, ok := val.(*[64]byte); ok {
271  		_, err := w.Write(b[:])
272  		return err
273  	}
274  	return NewTypeForEncodingErr(val, "[64]byte")
275  }
276  
277  // DBytes64 is an Decoder for 64-byte arrays. An error is returned if val is not
278  // a *[64]byte.
279  func DBytes64(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
280  	if b, ok := val.(*[64]byte); ok && l == 64 {
281  		_, err := io.ReadFull(r, b[:])
282  		return err
283  	}
284  	return NewTypeForDecodingErr(val, "[64]byte", l, 64)
285  }
286  
287  // EPubKey is an Encoder for *btcec.PublicKey values. An error is returned if
288  // val is not a **btcec.PublicKey.
289  func EPubKey(w io.Writer, val interface{}, _ *[8]byte) error {
290  	if pk, ok := val.(**btcec.PublicKey); ok {
291  		_, err := w.Write((*pk).SerializeCompressed())
292  		return err
293  	}
294  	return NewTypeForEncodingErr(val, "*btcec.PublicKey")
295  }
296  
297  // DPubKey is a Decoder for *btcec.PublicKey values. An error is returned if val
298  // is not a **btcec.PublicKey.
299  func DPubKey(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
300  	if pk, ok := val.(**btcec.PublicKey); ok && l == 33 {
301  		var b [33]byte
302  		_, err := io.ReadFull(r, b[:])
303  		if err != nil {
304  			return err
305  		}
306  
307  		p, err := btcec.ParsePubKey(b[:])
308  		if err != nil {
309  			return err
310  		}
311  
312  		*pk = p
313  
314  		return nil
315  	}
316  	return NewTypeForDecodingErr(val, "*btcec.PublicKey", l, 33)
317  }
318  
319  // EVarBytes is an Encoder for variable byte slices. An error is returned if val
320  // is not *[]byte.
321  func EVarBytes(w io.Writer, val interface{}, _ *[8]byte) error {
322  	if b, ok := val.(*[]byte); ok {
323  		_, err := w.Write(*b)
324  		return err
325  	}
326  	return NewTypeForEncodingErr(val, "[]byte")
327  }
328  
329  // DVarBytes is a Decoder for variable byte slices. An error is returned if val
330  // is not *[]byte.
331  func DVarBytes(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
332  	if b, ok := val.(*[]byte); ok {
333  		*b = make([]byte, l)
334  		_, err := io.ReadFull(r, *b)
335  		return err
336  	}
337  	return NewTypeForDecodingErr(val, "[]byte", l, l)
338  }
339  
340  // EBigSize encodes an uint32 or an uint64 using BigSize format. An error is
341  // returned if val is not either *uint32 or *uint64.
342  func EBigSize(w io.Writer, val interface{}, buf *[8]byte) error {
343  	if i, ok := val.(*uint32); ok {
344  		return WriteVarInt(w, uint64(*i), buf)
345  	}
346  
347  	if i, ok := val.(*uint64); ok {
348  		return WriteVarInt(w, uint64(*i), buf)
349  	}
350  
351  	return NewTypeForEncodingErr(val, "BigSize")
352  }
353  
354  // DBigSize decodes an uint32 or an uint64 using BigSize format. An error is
355  // returned if val is not either *uint32 or *uint64.
356  func DBigSize(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
357  	if i, ok := val.(*uint32); ok {
358  		v, err := ReadVarInt(r, buf)
359  		if err != nil {
360  			return err
361  		}
362  		*i = uint32(v)
363  		return nil
364  	}
365  
366  	if i, ok := val.(*uint64); ok {
367  		v, err := ReadVarInt(r, buf)
368  		if err != nil {
369  			return err
370  		}
371  		*i = v
372  		return nil
373  	}
374  
375  	return NewTypeForDecodingErr(val, "BigSize", l, 8)
376  }
377  
// constraintUint32Or64 limits a type parameter to exactly uint32 or uint64.
// No ~ is used, so named types defined on top of them do not satisfy it.
type constraintUint32Or64 interface{ uint32 | uint64 }
382  
383  // SizeBigSize returns a SizeFunc that can compute the length of BigSize.
384  func SizeBigSize[T constraintUint32Or64](val *T) SizeFunc {
385  	var size uint64
386  
387  	switch i := any(val).(type) {
388  	case *uint32:
389  		size = VarIntSize(uint64(*i))
390  	case *uint64:
391  		size = VarIntSize(*i)
392  	default:
393  		panic(fmt.Sprintf("unexpected type %T for SizeBigSize", val))
394  	}
395  
396  	return func() uint64 {
397  		return size
398  	}
399  }