summaryrefslogtreecommitdiff
path: root/errors/marshal.go
blob: 7c8a63ef92031429adede48e2392356b8f3e5025 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
package errors

import (
	"encoding/binary"
	"log"
)

func (e *Error) MarshalAppend(b []byte) []byte {
	if e == nil {
		return b
	}

	b = appendString(b, string(e.Op))

	var tmp [16]byte
	N := binary.PutVarint(tmp[:], int64(e.Kind))
	b = append(b, tmp[:N]...)
	b = MarshalErrorAppend(e.Err, b)
	return b
}

func (e *Error) MarshalBinary() ([]byte, error) {
	return e.MarshalAppend(nil), nil
}

func MarshalErrorAppend(err error, b []byte) []byte {
	if err == nil {
		return b
	}
	if e, ok := err.(*Error); ok {
		b = append(b, 'E')
		return e.MarshalAppend(b)
	}
	// Ordinary error.
	b = append(b, 'e')
	b = appendString(b, err.Error())
	return b
}

func MarshalError(err error) []byte {
	return MarshalErrorAppend(err, nil)
}

func (e *Error) UnmarshalBinary(b []byte) error {
	if len(b) == 0 {
		return nil
	}
	data, b := getBytes(b)
	if data != nil {
		e.Op = Op(data)
	}
	k, N := binary.Varint(b)
	e.Kind = Kind(k)
	b = b[N:]
	e.Err = UnmarshalError(b)
	return nil
}

func UnmarshalError(b []byte) error {
	if len(b) == 0 {
		return nil
	}
	code := b[0]
	b = b[1:]
	switch code {
	case 'e':
		var data []byte
		data, b = getBytes(b)
		if len(b) != 0 {
			log.Printf("Unmarshal error: trailing bytes")
		}
		return Str(string(data))
	case 'E':
		var err Error
		err.UnmarshalBinary(b)
		return &err
	default:
		log.Printf("Unmarshal error: corrupt data %q", b)
		return Str(string(b))
	}
}

func appendString(b []byte, str string) []byte {
	var tmp [16]byte
	N := binary.PutUvarint(tmp[:], uint64(len(str)))
	b = append(b, tmp[:N]...)
	b = append(b, str...)
	return b
}

func getBytes(b []byte) (data, remaining []byte) {
	u, N := binary.Uvarint(b)
	if len(b) < N+int(u) {
		log.Printf("Unmarshal error: bad encoding")
		return nil, nil
	}
	if N == 0 {
		log.Printf("Unmarshal error: bad encoding")
		return nil, b
	}
	return b[N : N+int(u)], b[N+int(u):]
}