commit c74282f87c41aad90bbd9edc99c3c1fe45d65953 from: Stephen Day via: GitHub date: Wed Nov 16 01:05:55 2016 UTC Merge pull request #30 from stevvooe/truncate-twrite-msize channel: truncate twrite messages based on msize commit - 529e2b2efc36aab8a17e32ba2516fc89d4cbd43a commit + c74282f87c41aad90bbd9edc99c3c1fe45d65953 blob - 82e719d61f2683fee7c710f0b6f02a3856a2679c blob + 984958022d113b06920ca64f2defedba90925620 --- channel.go +++ channel.go @@ -2,15 +2,19 @@ package p9p import ( "bufio" + "context" "encoding/binary" - "fmt" "io" "io/ioutil" "log" "net" "time" +) - "context" +const ( + // channelMessageHeaderSize is the overhead for sending the size of a + // message on the wire. + channelMessageHeaderSize = 4 ) // Channel defines the operations necessary to implement a 9p message channel @@ -114,6 +118,9 @@ func (ch *channel) SetMSize(msize int) { } // ReadFcall reads the next message from the channel into fcall. +// +// If the incoming message overflows the msize, Overflow(err) will return +// nonzero with the number of bytes overflowed. func (ch *channel) ReadFcall(ctx context.Context, fcall *Fcall) error { select { case <-ctx.Done(): @@ -140,9 +147,7 @@ func (ch *channel) ReadFcall(ctx context.Context, fcal } if n > len(ch.rdbuf) { - // TODO(stevvooe): Make this error detectable and respond with error - // message. - return fmt.Errorf("message too large for buffer: %v > %v ", n, len(ch.rdbuf)) + return overflowErr{size: n - len(ch.rdbuf)} } // clear out the fcall @@ -151,9 +156,19 @@ func (ch *channel) ReadFcall(ctx context.Context, fcal return err } + if err := ch.maybeTruncate(fcall); err != nil { + return err + } + return nil } +// WriteFcall writes the message to the connection. +// +// If a message destined for the wire will overflow MSize, an Overflow error +// may be returned. For Twrite calls, the buffer will simply be truncated to +// the optimal msize, with the caller detecting this condition with +// Rwrite.Count. 
func (ch *channel) WriteFcall(ctx context.Context, fcall *Fcall) error { select { case <-ctx.Done(): @@ -172,6 +187,10 @@ func (ch *channel) WriteFcall(ctx context.Context, fca log.Printf("transport: error setting read deadline on %v: %v", ch.conn.RemoteAddr(), err) } + if err := ch.maybeTruncate(fcall); err != nil { + return err + } + p, err := ch.codec.Marshal(fcall) if err != nil { return err @@ -184,6 +203,86 @@ func (ch *channel) WriteFcall(ctx context.Context, fca return ch.bwr.Flush() } +// maybeTruncate will truncate the message to fit into msize on the wire, if +// possible, or modify the message to ensure the response won't overflow. +// +// If the message cannot be truncated, an error will be returned and the +// message should not be sent. +// +// A nil return value means the message can be sent without further changes. +func (ch *channel) maybeTruncate(fcall *Fcall) error { + + // for certain message types, just remove the extra bytes from the data portion. + switch msg := fcall.Message.(type) { + // TODO(stevvooe): There is one more problematic message type: + // + // Rread: while we can employ the same truncation fix as Twrite, we + // need to make it observable to upstream handlers. + + case MessageTread: + // We can rewrite msg.Count so that a return message will be under + // msize. This is more defensive than anything but will ensure that + // calls don't fail on sloppy servers. + + // first, craft the shape of the response message + resp := newFcall(fcall.Tag, MessageRread{}) + overflow := uint32(ch.msgmsize(resp)) + msg.Count - uint32(ch.msize) + + if msg.Count < overflow { + // Let the bad thing happen; msize too small to even support valid + // rewrite. This will result in a Terror from the server-side or + // just work. 
+ return nil + } + + msg.Count -= overflow + fcall.Message = msg + + return nil + case MessageTwrite: + // If we are going to overflow the msize, we need to truncate the write to + // appropriate size or throw an error in all other conditions. + size := ch.msgmsize(fcall) + if size <= ch.msize { + return nil + } + + // overflow the msize, including the channel message size fields. + overflow := size - ch.msize + + if len(msg.Data) < overflow { + // paranoid: if msg.Data is not big enough to handle the + // overflow, we should get an overflow error. MSize would have + // to be way too small to be realistic. + return overflowErr{size: overflow} + } + + // The truncation is reflected in the return message (Rwrite) by + // the server, so we don't need a return value or error condition + // to communicate it. + msg.Data = msg.Data[:len(msg.Data)-overflow] + fcall.Message = msg // since we have a local copy + + return nil + default: + size := ch.msgmsize(fcall) + if size > ch.msize { + // overflow the msize, including the channel message size fields. + return overflowErr{size: size - ch.msize} + } + + return nil + } + +} + +// msgmsize returns the on-wire msize of the Fcall, including the size header. +// Typically, this can be used to detect whether or not the message overflows +// the msize buffer. +func (ch *channel) msgmsize(fcall *Fcall) int { + return channelMessageHeaderSize + ch.codec.Size(fcall) +} + // readmsg reads a 9p message into p from rd, ensuring that all bytes are // consumed from the size header. 
If the size header indicates the message is // larger than p, the entire message will be discarded, leaving a truncated blob - 0a06e0cce5533ef9907949cdef4d3564a63b3408 blob + 5884b06bf362b07fce7bf0404a289c16a7f88531 --- client.go +++ client.go @@ -1,6 +1,7 @@ package p9p import ( + "io" "net" "context" @@ -149,7 +150,17 @@ func (c *client) Read(ctx context.Context, fid Fid, p return 0, ErrUnexpectedMsg } - return copy(p, rread.Data), nil + n = copy(p, rread.Data) + switch { + case len(rread.Data) == 0: + err = io.EOF + case n < len(p): + // TODO(stevvooe): Technically, we should treat this as an io.EOF. + // However, we cannot tell if the short read was due to EOF or due to + // truncation. + } + + return n, err } func (c *client) Write(ctx context.Context, fid Fid, p []byte, offset int64) (n int, err error) { @@ -167,7 +178,11 @@ func (c *client) Write(ctx context.Context, fid Fid, p return 0, ErrUnexpectedMsg } - return int(rwrite.Count), nil + if int(rwrite.Count) < len(p) { + err = io.ErrShortWrite + } + + return int(rwrite.Count), err } func (c *client) Open(ctx context.Context, fid Fid, mode Flag) (Qid, uint32, error) { blob - /dev/null blob + fa6137bef1462d6f9d0c392e1a53bf455e63df02 (mode 644) --- /dev/null +++ channel_test.go @@ -0,0 +1,256 @@ +package p9p + +import ( + "bytes" + "context" + "encoding/binary" + "net" + "testing" + "time" +) + +// TestTwriteOverflow ensures that a Twrite message will have the data field +// truncated if the msize would be exceeded. +func TestTwriteOverflow(t *testing.T) { + const ( + msize = 512 + + // size[4] Twrite tag[2] fid[4] offset[8] count[4] data[count] | count = 0 + overhead = 4 + 1 + 2 + 4 + 8 + 4 + ) + + var ( + ctx = context.Background() + conn = &mockConn{} + ch = NewChannel(conn, msize) + ) + + for _, testcase := range []struct { + name string + overflow int // amount to overflow the message by. 
+ }{ + { + name: "BoundedOverflow", + overflow: msize / 2, + }, + { + name: "LargeOverflow", + overflow: msize * 3, + }, + { + name: "HeaderOverflow", + overflow: overhead, + }, + { + name: "HeaderOffsetOverflow", + overflow: overhead - 1, + }, + { + name: "OverflowByOne", + overflow: 1, + }, + } { + + t.Run(testcase.name, func(t *testing.T) { + var ( + fcall = overflowMessage(ch.(*channel).codec, msize, testcase.overflow) + data = fcall.Message.(MessageTwrite).Data + size uint32 + ) + + t.Logf("overflow: %v, len(data): %v, expected overflow: %v", testcase.overflow, len(data), overhead+len(data)-msize) + conn.buf.Reset() + if err := ch.WriteFcall(ctx, fcall); err != nil { + t.Fatal(err) + } + + if err := binary.Read(bytes.NewReader(conn.buf.Bytes()), binary.LittleEndian, &size); err != nil { + t.Fatal(err) + } + + if size != msize { + t.Fatalf("should have truncated size header: %d != %d", size, msize) + } + + if conn.buf.Len() != msize { + t.Fatalf("should have truncated message: conn.buf.Len(%v) != msize(%v)", conn.buf.Len(), msize) + } + }) + } + +} + +// TestWriteOverflowError ensures that we return an error in cases when there +// will certainly be an overflow and it cannot be resolved. +func TestWriteOverflowError(t *testing.T) { + const ( + msize = 4 + overflowMSize = msize + 1 + ) + + var ( + ctx = context.Background() + conn = &mockConn{} + ch = NewChannel(conn, msize) + data = bytes.Repeat([]byte{'A'}, 4) + fcall = newFcall(1, MessageTwrite{ + Data: data, + }) + messageSize = 4 + ch.(*channel).codec.Size(fcall) + ) + + err := ch.WriteFcall(ctx, fcall) + if err == nil { + t.Fatal("error expected when overflowing message") + } + + if Overflow(err) != messageSize-msize { + t.Fatalf("overflow should reflect messageSize and msize, %d != %d", Overflow(err), messageSize-msize) + } +} + +// TestReadFcallOverflow ensures that messages coming over a network connection do +// not overflow the msize. Invalid messages will cause `ReadFcall` to return an +// Overflow error. 
+func TestReadFcallOverflow(t *testing.T) { + const ( + msize = 256 + ) + + var ( + ctx = context.Background() + conn = &mockConn{} + ch = NewChannel(conn, msize) + codec = ch.(*channel).codec + ) + + for _, testcase := range []struct { + name string + overflow int + }{ + { + name: "OverflowByOne", + overflow: 1, + }, + { + name: "HeaderOverflow", + overflow: overheadMessage(codec, MessageTwrite{}), + }, + { + name: "HeaderOffsetOverflow", + overflow: overheadMessage(codec, MessageTwrite{}) - 1, + }, + } { + t.Run(testcase.name, func(t *testing.T) { + fcall := overflowMessage(codec, msize, testcase.overflow) + + // prepare the raw message + p, err := ch.(*channel).codec.Marshal(fcall) + if err != nil { + t.Fatal(err) + } + + // "send" the message into the buffer + // this message is crafted to overflow the read buffer. + if err := sendmsg(&conn.buf, p); err != nil { + t.Fatal(err) + } + + var incoming Fcall + err = ch.ReadFcall(ctx, &incoming) + if err == nil { + t.Fatal("expected error on fcall") + } + + // sanity check to ensure our test code has the right overflow + if testcase.overflow != ch.(*channel).msgmsize(fcall)-msize { + t.Fatalf("overflow calculation incorrect: %v != %v", testcase.overflow, ch.(*channel).msgmsize(fcall)-msize) + } + + if Overflow(err) != testcase.overflow { + t.Fatalf("unexpected overflow on error: %v !=%v", Overflow(err), testcase.overflow) + } + }) + } +} + +// TestTreadRewrite ensures that messages whose response would overflow +// the msize will be adjusted before sending. 
+func TestTreadRewrite(t *testing.T) { + const ( + msize = 256 + overflowMSize = msize + 1 + ) + + var ( + ctx = context.Background() + conn = &mockConn{} + ch = NewChannel(conn, msize) + buf = make([]byte, overflowMSize) + // data = bytes.Repeat([]byte{'A'}, overflowMSize) + fcall = newFcall(1, MessageTread{ + Count: overflowMSize, + }) + responseMSize = ch.(*channel).msgmsize(newFcall(1, MessageRread{ + Data: buf, + })) + ) + + if err := ch.WriteFcall(ctx, fcall); err != nil { + t.Fatal(err) + } + + // just read the message off the buffer + n, err := readmsg(&conn.buf, buf) + if err != nil { + t.Fatal(err) + } + + *fcall = Fcall{} + if err := ch.(*channel).codec.Unmarshal(buf[:n], fcall); err != nil { + t.Fatal(err) + } + + tread, ok := fcall.Message.(MessageTread) + if !ok { + t.Fatalf("unexpected message: %v", fcall) + } + + if tread.Count != overflowMSize-(uint32(responseMSize)-msize) { + t.Fatalf("count not rewritten: %v != %v", tread.Count, overflowMSize-(uint32(responseMSize)-msize)) + } +} + +type mockConn struct { + net.Conn + buf bytes.Buffer +} + +func (m mockConn) SetWriteDeadline(t time.Time) error { return nil } +func (m mockConn) SetReadDeadline(t time.Time) error { return nil } + +func (m *mockConn) Write(p []byte) (int, error) { + return m.buf.Write(p) +} + +func (m *mockConn) Read(p []byte) (int, error) { + return m.buf.Read(p) +} + +func overheadMessage(codec Codec, msg Message) int { + return 4 + codec.Size(newFcall(1, msg)) +} + +// overflowMessage returns a Twrite fcall whose on-wire size overflows msize +// by overflow bytes. 
+func overflowMessage(codec Codec, msize, overflow int) *Fcall { + var ( + overhead = overheadMessage(codec, MessageTwrite{}) + data = bytes.Repeat([]byte{'A'}, (msize-overhead)+overflow) + fcall = newFcall(1, MessageTwrite{ + Data: data, + }) + ) + + return fcall +} blob - 417aed22503a0e939ccfa2597bec7c69ca8b0ddf blob + 776e851f7bd9ecbf7ed8508af1878e5dca60d486 --- server.go +++ server.go @@ -132,6 +132,9 @@ func (c *conn) serve() error { } go func(ctx context.Context, req *Fcall) { + // TODO(stevvooe): Re-write incoming Treads so that handler + // can always respond with a message of the correct msize. + var resp *Fcall msg, err := c.handler.Handle(ctx, req.Message) if err != nil { @@ -208,6 +211,11 @@ func (c *conn) write(responses chan *Fcall) { for { select { case resp := <-responses: + // TODO(stevvooe): Correctly protect against overflowing msize from + // handler. This can be done above, in the main message handler + // loop, by adjusting incoming Tread calls to have a Count that + // won't overflow the msize. + if err := c.ch.WriteFcall(c.ctx, resp); err != nil { if err, ok := err.(net.Error); ok { if err.Timeout() || err.Temporary() { blob - /dev/null blob + ab4b7755ffe5d70016fbd5d03677a5a3a922243b (mode 644) --- /dev/null +++ overflow.go @@ -0,0 +1,49 @@ +package p9p + +import "fmt" + +// Overflow returns the number of bytes overflowed when err is, or is caused by, +// an overflow error, and 0 otherwise. +func Overflow(err error) int { + if of, ok := err.(overflow); ok { + return of.Size() + } + + // traverse cause, if above fails. + if causal, ok := err.(interface { + Cause() error + }); ok { + return Overflow(causal.Cause()) + } + + return 0 +} + +// overflow is a resolvable error type that can help callers negotiate +// session msize. If this error is encountered, no message was sent. +// +// The return value of `Size()` represents the number of bytes that would have +// been truncated if the message were sent. 
This IS NOT the optimal buffer size +// for operations like read and write. +// +// In the case of `Twrite`, the caller can subtract Size() from the local size to get an +// optimally sized buffer or the write can simply be truncated to `len(buf) - +// err.Size()`. +// +// For the most part, no users of this package should see this error in +// practice. If this escapes the Session interface, it is a bug. +type overflow interface { + Size() int // number of bytes overflowed. +} + +type overflowErr struct { + size int // number of bytes overflowed +} + +func (o overflowErr) Error() string { + return fmt.Sprintf("message overflowed %d bytes", o.size) +} + +func (o overflowErr) Size() int { + return o.size +} blob - 429cdab15ed44a658f612e6be6001cedc23dc7dc blob + eaca8d7018d33da28e0b906606c01aa64d360bfd --- session.go +++ session.go @@ -19,8 +19,17 @@ type Session interface { Clunk(ctx context.Context, fid Fid) error Remove(ctx context.Context, fid Fid) error Walk(ctx context.Context, fid Fid, newfid Fid, names ...string) ([]Qid, error) + + // Read follows the semantics of io.ReaderAt.ReadAt method except it takes + // a context and Fid. Read(ctx context.Context, fid Fid, p []byte, offset int64) (n int, err error) + + // Write follows the semantics of io.WriterAt.WriteAt except it takes a context and an Fid. + // + // If n == len(p), no error is returned. + // If n < len(p), io.ErrShortWrite will be returned. Write(ctx context.Context, fid Fid, p []byte, offset int64) (n int, err error) + Open(ctx context.Context, fid Fid, mode Flag) (Qid, uint32, error) Create(ctx context.Context, parent Fid, name string, perm uint32, mode Flag) (Qid, uint32, error) Stat(ctx context.Context, fid Fid) (Dir, error)