Blob


1 package p9p
3 import (
4 "bytes"
5 "context"
6 "encoding/binary"
7 "net"
8 "testing"
9 "time"
10 )
12 // TestTwriteOverflow ensures that a Twrite message will have the data field
13 // truncated if the msize would be exceeded.
14 func TestTwriteOverflow(t *testing.T) {
15 const (
16 msize = 512
18 // size[4] Twrite tag[2] fid[4] offset[8] count[4] data[count] | count = 0
19 overhead = 4 + 1 + 2 + 4 + 8 + 4
20 )
22 var (
23 ctx = context.Background()
24 conn = &mockConn{}
25 ch = NewChannel(conn, msize)
26 )
28 for _, testcase := range []struct {
29 name string
30 overflow int // amount to overflow the message by.
31 }{
32 {
33 name: "BoundedOverflow",
34 overflow: msize / 2,
35 },
36 {
37 name: "LargeOverflow",
38 overflow: msize * 3,
39 },
40 {
41 name: "HeaderOverflow",
42 overflow: overhead,
43 },
44 {
45 name: "HeaderOffsetOverflow",
46 overflow: overhead - 1,
47 },
48 {
49 name: "OverflowByOne",
50 overflow: 1,
51 },
52 } {
54 t.Run(testcase.name, func(t *testing.T) {
55 var (
56 fcall = overflowMessage(ch.(*channel).codec, msize, testcase.overflow)
57 data = fcall.Message.(MessageTwrite).Data
58 size uint32
59 )
61 t.Logf("overflow: %v, len(data): %v, expected overflow: %v", testcase.overflow, len(data), overhead+len(data)-msize)
62 conn.buf.Reset()
63 if err := ch.WriteFcall(ctx, fcall); err != nil {
64 t.Fatal(err)
65 }
67 if err := binary.Read(bytes.NewReader(conn.buf.Bytes()), binary.LittleEndian, &size); err != nil {
68 t.Fatal(err)
69 }
71 if size != msize {
72 t.Fatalf("should have truncated size header: %d != %d", size, msize)
73 }
75 if conn.buf.Len() != msize {
76 t.Fatalf("should have truncated message: conn.buf.Len(%v) != msize(%v)", conn.buf.Len(), msize)
77 }
78 })
79 }
81 }
83 // TestWriteOverflowError ensures that we return an error in cases when there
84 // will certainly be an overflow and it cannot be resolved.
85 func TestWriteOverflowError(t *testing.T) {
86 const (
87 msize = 4
88 overflowMSize = msize + 1
89 )
91 var (
92 ctx = context.Background()
93 conn = &mockConn{}
94 ch = NewChannel(conn, msize)
95 data = bytes.Repeat([]byte{'A'}, 4)
96 fcall = newFcall(1, MessageTwrite{
97 Data: data,
98 })
99 messageSize = 4 + ch.(*channel).codec.Size(fcall)
102 err := ch.WriteFcall(ctx, fcall)
103 if err == nil {
104 t.Fatal("error expected when overflowing message")
107 if Overflow(err) != messageSize-msize {
108 t.Fatalf("overflow should reflect messageSize and msize, %d != %d", Overflow(err), messageSize-msize)
112 // TestReadOverflow ensures that messages coming over a network connection do
113 // not overflow the msize. Invalid messages will cause `ReadFcall` to return an
114 // Overflow error.
115 func TestReadFcallOverflow(t *testing.T) {
116 const (
117 msize = 256
120 var (
121 ctx = context.Background()
122 conn = &mockConn{}
123 ch = NewChannel(conn, msize)
124 codec = ch.(*channel).codec
127 for _, testcase := range []struct {
128 name string
129 overflow int
130 }{
132 name: "OverflowByOne",
133 overflow: 1,
134 },
136 name: "HeaderOverflow",
137 overflow: overheadMessage(codec, MessageTwrite{}),
138 },
140 name: "HeaderOffsetOverflow",
141 overflow: overheadMessage(codec, MessageTwrite{}) - 1,
142 },
143 } {
144 t.Run(testcase.name, func(t *testing.T) {
145 fcall := overflowMessage(codec, msize, testcase.overflow)
147 // prepare the raw message
148 p, err := ch.(*channel).codec.Marshal(fcall)
149 if err != nil {
150 t.Fatal(err)
153 // "send" the message into the buffer
154 // this message is crafted to overflow the read buffer.
155 if err := sendmsg(&conn.buf, p); err != nil {
156 t.Fatal(err)
159 var incoming Fcall
160 err = ch.ReadFcall(ctx, &incoming)
161 if err == nil {
162 t.Fatal("expected error on fcall")
165 // sanity check to ensure our test code has the right overflow
166 if testcase.overflow != ch.(*channel).msgmsize(fcall)-msize {
167 t.Fatalf("overflow calculation incorrect: %v != %v", testcase.overflow, ch.(*channel).msgmsize(fcall)-msize)
170 if Overflow(err) != testcase.overflow {
171 t.Fatalf("unexpected overflow on error: %v !=%v", Overflow(err), testcase.overflow)
173 })
177 // TestTreadRewrite ensures that messages that whose response would overflow
178 // the msize will have be adjusted before sending.
179 func TestTreadRewrite(t *testing.T) {
180 const (
181 msize = 256
182 overflowMSize = msize + 1
185 var (
186 ctx = context.Background()
187 conn = &mockConn{}
188 ch = NewChannel(conn, msize)
189 buf = make([]byte, overflowMSize)
190 // data = bytes.Repeat([]byte{'A'}, overflowMSize)
191 fcall = newFcall(1, MessageTread{
192 Count: overflowMSize,
193 })
194 responseMSize = ch.(*channel).msgmsize(newFcall(1, MessageRread{
195 Data: buf,
196 }))
199 if err := ch.WriteFcall(ctx, fcall); err != nil {
200 t.Fatal(err)
203 // just read the message off the buffer
204 n, err := readmsg(&conn.buf, buf)
205 if err != nil {
206 t.Fatal(err)
209 *fcall = Fcall{}
210 if err := ch.(*channel).codec.Unmarshal(buf[:n], fcall); err != nil {
211 t.Fatal(err)
214 tread, ok := fcall.Message.(MessageTread)
215 if !ok {
216 t.Fatalf("unexpected message: %v", fcall)
219 if tread.Count != overflowMSize-(uint32(responseMSize)-msize) {
220 t.Fatalf("count not rewritten: %v != %v", tread.Count, overflowMSize-(uint32(responseMSize)-msize))
// mockConn is an in-memory stand-in for a net.Conn: Write appends to buf and
// Read consumes from it. Methods not defined here fall through to the embedded
// net.Conn. Tests in this file leave it nil, so calling such a method would
// panic.
type mockConn struct {
	net.Conn
	buf bytes.Buffer
}

// SetWriteDeadline is a no-op that always returns nil.
func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }

// SetReadDeadline is a no-op that always returns nil.
func (m *mockConn) SetReadDeadline(t time.Time) error { return nil }

// Write appends p to the in-memory buffer.
func (m *mockConn) Write(p []byte) (int, error) {
	return m.buf.Write(p)
}

// Read drains bytes from the in-memory buffer. As with bytes.Buffer.Read, it
// returns io.EOF once the buffer is empty and len(p) > 0.
func (m *mockConn) Read(p []byte) (int, error) {
	return m.buf.Read(p)
}
240 func overheadMessage(codec Codec, msg Message) int {
241 return 4 + codec.Size(newFcall(1, msg))
244 // overflowMessage returns message that overflows the msize by overflow bytes,
245 // returning the message size and the fcall.
246 func overflowMessage(codec Codec, msize, overflow int) *Fcall {
247 var (
248 overhead = overheadMessage(codec, MessageTwrite{})
249 data = bytes.Repeat([]byte{'A'}, (msize-overhead)+overflow)
250 fcall = newFcall(1, MessageTwrite{
251 Data: data,
252 })
255 return fcall