commit 0f5f58bba93f6b0a435fdf59157ef453a2291ba1 from: Stephen J Day date: Wed Nov 16 00:47:26 2016 UTC channel: truncate twrite messages based on msize While there are a few problems around handling of msize, the easiest to address and, arguably, the most problematic is that of Twrite. We now truncate Twrite.Data to the correct length if it will overflow the msize limit negotiated on the session. ErrShortWrite is returned by the `Session.Write` method if written data is truncated. In addition, we now reject incoming messages from `ReadFcall` that overflow the msize. Such messages are probably terminal in practice, but can be detected with the `Overflow` function. Tread is also handled accordingly, such that the Count field will be rewritten such that the response doesn't overflow the msize. Signed-off-by: Stephen J Day commit - 529e2b2efc36aab8a17e32ba2516fc89d4cbd43a commit + 0f5f58bba93f6b0a435fdf59157ef453a2291ba1 blob - 82e719d61f2683fee7c710f0b6f02a3856a2679c blob + 984958022d113b06920ca64f2defedba90925620 --- channel.go +++ channel.go @@ -2,15 +2,19 @@ package p9p import ( "bufio" + "context" "encoding/binary" - "fmt" "io" "io/ioutil" "log" "net" "time" +) - "context" +const ( + // channelMessageHeaderSize is the overhead for sending the size of a + // message on the wire. + channelMessageHeaderSize = 4 ) // Channel defines the operations necessary to implement a 9p message channel @@ -114,6 +118,9 @@ func (ch *channel) SetMSize(msize int) { } // ReadFcall reads the next message from the channel into fcall. +// +// If the incoming message overflows the msize, Overflow(err) will return +// nonzero with the number of bytes overflowed. func (ch *channel) ReadFcall(ctx context.Context, fcall *Fcall) error { select { case <-ctx.Done(): @@ -140,9 +147,7 @@ func (ch *channel) ReadFcall(ctx context.Context, fcal } if n > len(ch.rdbuf) { - // TODO(stevvooe): Make this error detectable and respond with error - // message. 
- return fmt.Errorf("message too large for buffer: %v > %v ", n, len(ch.rdbuf)) + return overflowErr{size: n - len(ch.rdbuf)} } // clear out the fcall @@ -151,9 +156,19 @@ func (ch *channel) ReadFcall(ctx context.Context, fcal return err } + if err := ch.maybeTruncate(fcall); err != nil { + return err + } + return nil } +// WriteFcall writes the message to the connection. +// +// If a message destined for the wire will overflow MSize, an Overflow error +// may be returned. For Twrite calls, the buffer will simply be truncated to +// the optimal msize, with the caller detecting this condition with +// Rwrite.Count. func (ch *channel) WriteFcall(ctx context.Context, fcall *Fcall) error { select { case <-ctx.Done(): @@ -172,6 +187,10 @@ func (ch *channel) WriteFcall(ctx context.Context, fca log.Printf("transport: error setting read deadline on %v: %v", ch.conn.RemoteAddr(), err) } + if err := ch.maybeTruncate(fcall); err != nil { + return err + } + p, err := ch.codec.Marshal(fcall) if err != nil { return err @@ -184,6 +203,86 @@ func (ch *channel) WriteFcall(ctx context.Context, fca return ch.bwr.Flush() } +// maybeTruncate will truncate the message to fit into msize on the wire, if +// possible, or modify the message to ensure the response won't overflow. +// +// If the message cannot be truncated, an error will be returned and the +// message should not be sent. +// +// A nil return value means the message can be sent, possibly modified. +func (ch *channel) maybeTruncate(fcall *Fcall) error { + + // for certain message types, just remove the extra bytes from the data portion. + switch msg := fcall.Message.(type) { + // TODO(stevvooe): There is one more problematic message type: + // + // Rread: while we can employ the same truncation fix as Twrite, we + // need to make it observable to upstream handlers. + + case MessageTread: + // We can rewrite msg.Count so that a return message will be under + // msize. 
This is more defensive than anything but will ensure that + // calls don't fail on sloppy servers. + + // first, craft the shape of the response message + resp := newFcall(fcall.Tag, MessageRread{}) + overflow := uint32(ch.msgmsize(resp)) + msg.Count - uint32(ch.msize) + + if msg.Count < overflow { + // Let the bad thing happen; msize too small to even support valid + // rewrite. This will result in a Terror from the server-side or + // just work. + return nil + } + + msg.Count -= overflow + fcall.Message = msg + + return nil + case MessageTwrite: + // If we are going to overflow the msize, we need to truncate the write to + // appropriate size or throw an error in all other conditions. + size := ch.msgmsize(fcall) + if size <= ch.msize { + return nil + } + + // overflow the msize, including the channel message size fields. + overflow := size - ch.msize + + if len(msg.Data) < overflow { + // paranoid: if msg.Data is not big enough to handle the + // overflow, we should get an overflow error. MSize would have + // to be way too small to be realistic. + return overflowErr{size: overflow} + } + + // The truncation is reflected in the return message (Rwrite) by + // the server, so we don't need a return value or error condition + // to communicate it. + msg.Data = msg.Data[:len(msg.Data)-overflow] + fcall.Message = msg // since we have a local copy + + return nil + default: + size := ch.msgmsize(fcall) + if size > ch.msize { + // overflow the msize, including the channel message size fields. + return overflowErr{size: size - ch.msize} + } + + return nil + } + +} + +// msgmsize returns the on-wire msize of the Fcall, including the size header. +// Typically, this can be used to detect whether or not the message overflows +// the msize buffer. +func (ch *channel) msgmsize(fcall *Fcall) int { + return channelMessageHeaderSize + ch.codec.Size(fcall) +} + // readmsg reads a 9p message into p from rd, ensuring that all bytes are // consumed from the size header. 
If the size header indicates the message is // larger than p, the entire message will be discarded, leaving a truncated blob - 0a06e0cce5533ef9907949cdef4d3564a63b3408 blob + 5884b06bf362b07fce7bf0404a289c16a7f88531 --- client.go +++ client.go @@ -1,6 +1,7 @@ package p9p import ( + "io" "net" "context" @@ -149,7 +150,17 @@ func (c *client) Read(ctx context.Context, fid Fid, p return 0, ErrUnexpectedMsg } - return copy(p, rread.Data), nil + n = copy(p, rread.Data) + switch { + case len(rread.Data) == 0: + err = io.EOF + case n < len(p): + // TODO(stevvooe): Technically, we should treat this as an io.EOF. + // However, we cannot tell if the short read was due to EOF or due to + // truncation. + } + + return n, err } func (c *client) Write(ctx context.Context, fid Fid, p []byte, offset int64) (n int, err error) { @@ -167,7 +178,11 @@ func (c *client) Write(ctx context.Context, fid Fid, p return 0, ErrUnexpectedMsg } - return int(rwrite.Count), nil + if int(rwrite.Count) < len(p) { + err = io.ErrShortWrite + } + + return int(rwrite.Count), err } func (c *client) Open(ctx context.Context, fid Fid, mode Flag) (Qid, uint32, error) { blob - /dev/null blob + fa6137bef1462d6f9d0c392e1a53bf455e63df02 (mode 644) --- /dev/null +++ channel_test.go @@ -0,0 +1,256 @@ +package p9p + +import ( + "bytes" + "context" + "encoding/binary" + "net" + "testing" + "time" +) + +// TestTwriteOverflow ensures that a Twrite message will have the data field +// truncated if the msize would be exceeded. +func TestTwriteOverflow(t *testing.T) { + const ( + msize = 512 + + // size[4] Twrite tag[2] fid[4] offset[8] count[4] data[count] | count = 0 + overhead = 4 + 1 + 2 + 4 + 8 + 4 + ) + + var ( + ctx = context.Background() + conn = &mockConn{} + ch = NewChannel(conn, msize) + ) + + for _, testcase := range []struct { + name string + overflow int // amount to overflow the message by. 
+ }{ + { + name: "BoundedOverflow", + overflow: msize / 2, + }, + { + name: "LargeOverflow", + overflow: msize * 3, + }, + { + name: "HeaderOverflow", + overflow: overhead, + }, + { + name: "HeaderOffsetOverflow", + overflow: overhead - 1, + }, + { + name: "OverflowByOne", + overflow: 1, + }, + } { + + t.Run(testcase.name, func(t *testing.T) { + var ( + fcall = overflowMessage(ch.(*channel).codec, msize, testcase.overflow) + data = fcall.Message.(MessageTwrite).Data + size uint32 + ) + + t.Logf("overflow: %v, len(data): %v, expected overflow: %v", testcase.overflow, len(data), overhead+len(data)-msize) + conn.buf.Reset() + if err := ch.WriteFcall(ctx, fcall); err != nil { + t.Fatal(err) + } + + if err := binary.Read(bytes.NewReader(conn.buf.Bytes()), binary.LittleEndian, &size); err != nil { + t.Fatal(err) + } + + if size != msize { + t.Fatalf("should have truncated size header: %d != %d", size, msize) + } + + if conn.buf.Len() != msize { + t.Fatalf("should have truncated message: conn.buf.Len(%v) != msize(%v)", conn.buf.Len(), msize) + } + }) + } + +} + +// TestWriteOverflowError ensures that we return an error in cases when there +// will certainly be an overflow and it cannot be resolved. +func TestWriteOverflowError(t *testing.T) { + const ( + msize = 4 + overflowMSize = msize + 1 + ) + + var ( + ctx = context.Background() + conn = &mockConn{} + ch = NewChannel(conn, msize) + data = bytes.Repeat([]byte{'A'}, 4) + fcall = newFcall(1, MessageTwrite{ + Data: data, + }) + messageSize = 4 + ch.(*channel).codec.Size(fcall) + ) + + err := ch.WriteFcall(ctx, fcall) + if err == nil { + t.Fatal("error expected when overflowing message") + } + + if Overflow(err) != messageSize-msize { + t.Fatalf("overflow should reflect messageSize and msize, %d != %d", Overflow(err), messageSize-msize) + } +} + +// TestReadFcallOverflow ensures that messages coming over a network connection do +// not overflow the msize. Invalid messages will cause `ReadFcall` to return an +// Overflow error. 
+func TestReadFcallOverflow(t *testing.T) { + const ( + msize = 256 + ) + + var ( + ctx = context.Background() + conn = &mockConn{} + ch = NewChannel(conn, msize) + codec = ch.(*channel).codec + ) + + for _, testcase := range []struct { + name string + overflow int + }{ + { + name: "OverflowByOne", + overflow: 1, + }, + { + name: "HeaderOverflow", + overflow: overheadMessage(codec, MessageTwrite{}), + }, + { + name: "HeaderOffsetOverflow", + overflow: overheadMessage(codec, MessageTwrite{}) - 1, + }, + } { + t.Run(testcase.name, func(t *testing.T) { + fcall := overflowMessage(codec, msize, testcase.overflow) + + // prepare the raw message + p, err := ch.(*channel).codec.Marshal(fcall) + if err != nil { + t.Fatal(err) + } + + // "send" the message into the buffer + // this message is crafted to overflow the read buffer. + if err := sendmsg(&conn.buf, p); err != nil { + t.Fatal(err) + } + + var incoming Fcall + err = ch.ReadFcall(ctx, &incoming) + if err == nil { + t.Fatal("expected error on fcall") + } + + // sanity check to ensure our test code has the right overflow + if testcase.overflow != ch.(*channel).msgmsize(fcall)-msize { + t.Fatalf("overflow calculation incorrect: %v != %v", testcase.overflow, ch.(*channel).msgmsize(fcall)-msize) + } + + if Overflow(err) != testcase.overflow { + t.Fatalf("unexpected overflow on error: %v !=%v", Overflow(err), testcase.overflow) + } + }) + } +} + +// TestTreadRewrite ensures that messages whose response would overflow +// the msize will be adjusted before sending. 
+func TestTreadRewrite(t *testing.T) { + const ( + msize = 256 + overflowMSize = msize + 1 + ) + + var ( + ctx = context.Background() + conn = &mockConn{} + ch = NewChannel(conn, msize) + buf = make([]byte, overflowMSize) + // data = bytes.Repeat([]byte{'A'}, overflowMSize) + fcall = newFcall(1, MessageTread{ + Count: overflowMSize, + }) + responseMSize = ch.(*channel).msgmsize(newFcall(1, MessageRread{ + Data: buf, + })) + ) + + if err := ch.WriteFcall(ctx, fcall); err != nil { + t.Fatal(err) + } + + // just read the message off the buffer + n, err := readmsg(&conn.buf, buf) + if err != nil { + t.Fatal(err) + } + + *fcall = Fcall{} + if err := ch.(*channel).codec.Unmarshal(buf[:n], fcall); err != nil { + t.Fatal(err) + } + + tread, ok := fcall.Message.(MessageTread) + if !ok { + t.Fatalf("unexpected message: %v", fcall) + } + + if tread.Count != overflowMSize-(uint32(responseMSize)-msize) { + t.Fatalf("count not rewritten: %v != %v", tread.Count, overflowMSize-(uint32(responseMSize)-msize)) + } +} + +type mockConn struct { + net.Conn + buf bytes.Buffer +} + +func (m mockConn) SetWriteDeadline(t time.Time) error { return nil } +func (m mockConn) SetReadDeadline(t time.Time) error { return nil } + +func (m *mockConn) Write(p []byte) (int, error) { + return m.buf.Write(p) +} + +func (m *mockConn) Read(p []byte) (int, error) { + return m.buf.Read(p) +} + +func overheadMessage(codec Codec, msg Message) int { + return 4 + codec.Size(newFcall(1, msg)) +} + +// overflowMessage returns a Twrite fcall that overflows the msize by overflow +// bytes. 
+func overflowMessage(codec Codec, msize, overflow int) *Fcall { + var ( + overhead = overheadMessage(codec, MessageTwrite{}) + data = bytes.Repeat([]byte{'A'}, (msize-overhead)+overflow) + fcall = newFcall(1, MessageTwrite{ + Data: data, + }) + ) + + return fcall +} blob - 417aed22503a0e939ccfa2597bec7c69ca8b0ddf blob + 776e851f7bd9ecbf7ed8508af1878e5dca60d486 --- server.go +++ server.go @@ -132,6 +132,9 @@ func (c *conn) serve() error { } go func(ctx context.Context, req *Fcall) { + // TODO(stevvooe): Re-write incoming Treads so that handler + // can always respond with a message of the correct msize. + var resp *Fcall msg, err := c.handler.Handle(ctx, req.Message) if err != nil { @@ -208,6 +211,11 @@ func (c *conn) write(responses chan *Fcall) { for { select { case resp := <-responses: + // TODO(stevvooe): Correctly protect against overflowing msize from + // handler. This can be done above, in the main message handler + // loop, by adjusting incoming Tread calls to have a Count that + // won't overflow the msize. + if err := c.ch.WriteFcall(c.ctx, resp); err != nil { if err, ok := err.(net.Error); ok { if err.Timeout() || err.Temporary() { blob - /dev/null blob + ab4b7755ffe5d70016fbd5d03677a5a3a922243b (mode 644) --- /dev/null +++ overflow.go @@ -0,0 +1,49 @@ +package p9p + +import "fmt" + +// Overflow returns the number of bytes overflowed if err is, or is caused by, +// an overflow error, and zero otherwise. +func Overflow(err error) int { + if of, ok := err.(overflow); ok { + return of.Size() + } + + // traverse cause, if above fails. + if causal, ok := err.(interface { + Cause() error + }); ok { + return Overflow(causal.Cause()) + } + + return 0 +} + +// overflow is a resolvable error type that can help callers negotiate +// session msize. If this error is encountered, no message was sent. +// +// The return value of `Size()` represents the number of bytes that would have +// been truncated if the message were sent. 
This IS NOT the optimal buffer size +// for operations like read and write. +// +// In the case of `Twrite`, the caller can subtract Size() from the local size to get an +// optimally sized buffer or the write can simply be truncated to `len(buf) - +// err.Size()`. +// +// For the most part, no users of this package should see this error in +// practice. If this escapes the Session interface, it is a bug. +type overflow interface { + Size() int // number of bytes overflowed. +} + +type overflowErr struct { + size int // number of bytes overflowed +} + +func (o overflowErr) Error() string { + return fmt.Sprintf("message overflowed %d bytes", o.size) +} + +func (o overflowErr) Size() int { + return o.size +} blob - 429cdab15ed44a658f612e6be6001cedc23dc7dc blob + eaca8d7018d33da28e0b906606c01aa64d360bfd --- session.go +++ session.go @@ -19,8 +19,17 @@ type Session interface { Clunk(ctx context.Context, fid Fid) error Remove(ctx context.Context, fid Fid) error Walk(ctx context.Context, fid Fid, newfid Fid, names ...string) ([]Qid, error) + + // Read follows the semantics of io.ReaderAt.ReadAt method except it takes + // a context and Fid. Read(ctx context.Context, fid Fid, p []byte, offset int64) (n int, err error) + + // Write follows the semantics of io.WriterAt.WriteAt except it takes a context and an Fid. + // + // If n == len(p), no error is returned. + // If n < len(p), io.ErrShortWrite will be returned. Write(ctx context.Context, fid Fid, p []byte, offset int64) (n int, err error) + Open(ctx context.Context, fid Fid, mode Flag) (Qid, uint32, error) Create(ctx context.Context, parent Fid, name string, perm uint32, mode Flag) (Qid, uint32, error) Stat(ctx context.Context, fid Fid) (Dir, error)