12 // TestTwriteOverflow ensures that a Twrite message will have the data field
13 // truncated if the msize would be exceeded.
14 func TestTwriteOverflow(t *testing.T) {
18 // size[4] Twrite tag[2] fid[4] offset[8] count[4] data[count] | count = 0
19 overhead = 4 + 1 + 2 + 4 + 8 + 4
23 ctx = context.Background()
25 ch = NewChannel(conn, msize)
28 for _, testcase := range []struct {
30 overflow int // amount to overflow the message by.
33 name: "BoundedOverflow",
37 name: "LargeOverflow",
41 name: "HeaderOverflow",
45 name: "HeaderOffsetOverflow",
46 overflow: overhead - 1,
49 name: "OverflowByOne",
54 t.Run(testcase.name, func(t *testing.T) {
56 fcall = overflowMessage(ch.(*channel).codec, msize, testcase.overflow)
57 data = fcall.Message.(MessageTwrite).Data
61 t.Logf("overflow: %v, len(data): %v, expected overflow: %v", testcase.overflow, len(data), overhead+len(data)-msize)
63 if err := ch.WriteFcall(ctx, fcall); err != nil {
67 if err := binary.Read(bytes.NewReader(conn.buf.Bytes()), binary.LittleEndian, &size); err != nil {
72 t.Fatalf("should have truncated size header: %d != %d", size, msize)
75 if conn.buf.Len() != msize {
76 t.Fatalf("should have truncated message: conn.buf.Len(%v) != msize(%v)", conn.buf.Len(), msize)
83 // TestWriteOverflowError ensures that we return an error in cases when there
84 // will certainly be an overflow and it cannot be resolved.
85 func TestWriteOverflowError(t *testing.T) {
88 overflowMSize = msize + 1
92 ctx = context.Background()
94 ch = NewChannel(conn, msize)
95 data = bytes.Repeat([]byte{'A'}, 4)
96 fcall = newFcall(1, MessageTwrite{
99 messageSize = 4 + ch.(*channel).codec.Size(fcall)
102 err := ch.WriteFcall(ctx, fcall)
104 t.Fatal("error expected when overflowing message")
107 if Overflow(err) != messageSize-msize {
108 t.Fatalf("overflow should reflect messageSize and msize, %d != %d", Overflow(err), messageSize-msize)
112 // TestReadOverflow ensures that messages coming over a network connection do
113 // not overflow the msize. Invalid messages will cause `ReadFcall` to return an
115 func TestReadFcallOverflow(t *testing.T) {
121 ctx = context.Background()
123 ch = NewChannel(conn, msize)
124 codec = ch.(*channel).codec
127 for _, testcase := range []struct {
132 name: "OverflowByOne",
136 name: "HeaderOverflow",
137 overflow: overheadMessage(codec, MessageTwrite{}),
140 name: "HeaderOffsetOverflow",
141 overflow: overheadMessage(codec, MessageTwrite{}) - 1,
144 t.Run(testcase.name, func(t *testing.T) {
145 fcall := overflowMessage(codec, msize, testcase.overflow)
147 // prepare the raw message
148 p, err := ch.(*channel).codec.Marshal(fcall)
153 // "send" the message into the buffer
154 // this message is crafted to overflow the read buffer.
155 if err := sendmsg(&conn.buf, p); err != nil {
160 err = ch.ReadFcall(ctx, &incoming)
162 t.Fatal("expected error on fcall")
165 // sanity check to ensure our test code has the right overflow
166 if testcase.overflow != ch.(*channel).msgmsize(fcall)-msize {
167 t.Fatalf("overflow calculation incorrect: %v != %v", testcase.overflow, ch.(*channel).msgmsize(fcall)-msize)
170 if Overflow(err) != testcase.overflow {
171 t.Fatalf("unexpected overflow on error: %v !=%v", Overflow(err), testcase.overflow)
177 // TestTreadRewrite ensures that messages that whose response would overflow
178 // the msize will have be adjusted before sending.
179 func TestTreadRewrite(t *testing.T) {
182 overflowMSize = msize + 1
186 ctx = context.Background()
188 ch = NewChannel(conn, msize)
189 buf = make([]byte, overflowMSize)
190 // data = bytes.Repeat([]byte{'A'}, overflowMSize)
191 fcall = newFcall(1, MessageTread{
192 Count: overflowMSize,
194 responseMSize = ch.(*channel).msgmsize(newFcall(1, MessageRread{
199 if err := ch.WriteFcall(ctx, fcall); err != nil {
203 // just read the message off the buffer
204 n, err := readmsg(&conn.buf, buf)
210 if err := ch.(*channel).codec.Unmarshal(buf[:n], fcall); err != nil {
214 tread, ok := fcall.Message.(MessageTread)
216 t.Fatalf("unexpected message: %v", fcall)
219 if tread.Count != overflowMSize-(uint32(responseMSize)-msize) {
220 t.Fatalf("count not rewritten: %v != %v", tread.Count, overflowMSize-(uint32(responseMSize)-msize))
224 type mockConn struct {
229 func (m mockConn) SetWriteDeadline(t time.Time) error { return nil }
230 func (m mockConn) SetReadDeadline(t time.Time) error { return nil }
232 func (m *mockConn) Write(p []byte) (int, error) {
233 return m.buf.Write(p)
236 func (m *mockConn) Read(p []byte) (int, error) {
240 func overheadMessage(codec Codec, msg Message) int {
241 return 4 + codec.Size(newFcall(1, msg))
244 // overflowMessage returns message that overflows the msize by overflow bytes,
245 // returning the message size and the fcall.
246 func overflowMessage(codec Codec, msize, overflow int) *Fcall {
248 overhead = overheadMessage(codec, MessageTwrite{})
249 data = bytes.Repeat([]byte{'A'}, (msize-overhead)+overflow)
250 fcall = newFcall(1, MessageTwrite{