commit - 529e2b2efc36aab8a17e32ba2516fc89d4cbd43a
commit + 0f5f58bba93f6b0a435fdf59157ef453a2291ba1
blob - 82e719d61f2683fee7c710f0b6f02a3856a2679c
blob + 984958022d113b06920ca64f2defedba90925620
--- channel.go
+++ channel.go
import (
"bufio"
+ "context"
"encoding/binary"
- "fmt"
"io"
"io/ioutil"
"log"
"net"
"time"
+)
- "context"
+const (
+ // channelMessageHeaderSize is the overhead, in bytes, of the size field
+ // that prefixes every message on the wire. msgmsize adds it to the codec's
+ // encoded size to get a message's full on-wire size.
+ channelMessageHeaderSize = 4
)
// Channel defines the operations necessary to implement a 9p message channel
}
// ReadFcall reads the next message from the channel into fcall.
+//
+// If the incoming message overflows the msize, Overflow(err) will return
+// nonzero with the number of bytes overflowed.
func (ch *channel) ReadFcall(ctx context.Context, fcall *Fcall) error {
select {
case <-ctx.Done():
}
if n > len(ch.rdbuf) {
- // TODO(stevvooe): Make this error detectable and respond with error
- // message.
- return fmt.Errorf("message too large for buffer: %v > %v ", n, len(ch.rdbuf))
+ return overflowErr{size: n - len(ch.rdbuf)}
}
// clear out the fcall
return err
}
+ if err := ch.maybeTruncate(fcall); err != nil {
+ return err
+ }
+
return nil
}
+// WriteFcall writes the message to the connection.
+//
+// If a message destined for the wire will overflow MSize, an Overflow error
+// may be returned. For Twrite calls, the buffer will simply be truncated to
+// the optimal msize, with the caller detecting this condition with
+// Rwrite.Count.
func (ch *channel) WriteFcall(ctx context.Context, fcall *Fcall) error {
select {
case <-ctx.Done():
log.Printf("transport: error setting read deadline on %v: %v", ch.conn.RemoteAddr(), err)
}
+ if err := ch.maybeTruncate(fcall); err != nil {
+ return err
+ }
+
p, err := ch.codec.Marshal(fcall)
if err != nil {
return err
return ch.bwr.Flush()
}
+// maybeTruncate will truncate the message to fit into msize on the wire, if
+// possible, or modify the message to ensure the response won't overflow.
+//
+// If the message cannot be truncated, an error will be returned and the
+// message should not be sent.
+//
+// A nil return value means the message can be sent without overflowing msize.
+// Twrite and Tread messages may have been rewritten in place (fcall.Message
+// replaced) to achieve this. In the degenerate case where msize cannot even
+// hold an empty Rread, a Tread is passed through unchanged.
+func (ch *channel) maybeTruncate(fcall *Fcall) error {
+	// for certain message types, just remove the extra bytes from the data portion.
+	switch msg := fcall.Message.(type) {
+	// TODO(stevvooe): There is one more problematic message type:
+	//
+	// Rread: while we can employ the same truncation fix as Twrite, we
+	// need to make it observable to upstream handlers.
+
+	case MessageTread:
+		// We can rewrite msg.Count so that a return message will be under
+		// msize. This is more defensive than anything but will ensure that
+		// calls don't fail on sloppy servers.
+		//
+		// First, craft the shape of the response message and work out by how
+		// many bytes a full-sized response would exceed msize. The arithmetic
+		// is done in int64 so a request whose response fits yields a
+		// non-positive value, instead of relying on uint32 wrap-around.
+		resp := newFcall(fcall.Tag, MessageRread{})
+		overflow := int64(ch.msgmsize(resp)) + int64(msg.Count) - int64(ch.msize)
+		if overflow <= 0 {
+			// The response already fits; nothing to rewrite.
+			return nil
+		}
+
+		if int64(msg.Count) < overflow {
+			// Let the bad thing happen; msize too small to even support valid
+			// rewrite. This will result in a Terror from the server-side or
+			// just work.
+			return nil
+		}
+
+		// overflow <= msg.Count here, so the conversion cannot truncate.
+		msg.Count -= uint32(overflow)
+		fcall.Message = msg // since we have a local copy
+
+		return nil
+	case MessageTwrite:
+		// If we are going to overflow the msize, we need to truncate the write to
+		// appropriate size or throw an error in all other conditions.
+		size := ch.msgmsize(fcall)
+		if size <= ch.msize {
+			return nil
+		}
+
+		// overflow the msize, including the channel message size fields.
+		overflow := size - ch.msize
+
+		if len(msg.Data) < overflow {
+			// paranoid: if msg.Data is not big enough to handle the
+			// overflow, we should get an overflow error. MSize would have
+			// to be way too small to be realistic.
+			return overflowErr{size: overflow}
+		}
+
+		// The truncation is reflected in the return message (Rwrite) by
+		// the server, so we don't need a return value or error condition
+		// to communicate it.
+		msg.Data = msg.Data[:len(msg.Data)-overflow]
+		fcall.Message = msg // since we have a local copy
+
+		return nil
+	default:
+		size := ch.msgmsize(fcall)
+		if size > ch.msize {
+			// overflow the msize, including the channel message size fields.
+			return overflowErr{size: size - ch.msize}
+		}
+
+		return nil
+	}
+}
+
+// msgmsize computes the number of bytes fcall occupies once framed for the
+// wire: the codec's encoded size plus the channelMessageHeaderSize length
+// prefix. Callers compare the result against msize to detect overflow.
+func (ch *channel) msgmsize(fcall *Fcall) int {
+	encoded := ch.codec.Size(fcall)
+	return encoded + channelMessageHeaderSize
+}
+
// readmsg reads a 9p message into p from rd, ensuring that all bytes are
// consumed from the size header. If the size header indicates the message is
// larger than p, the entire message will be discarded, leaving a truncated
blob - 0a06e0cce5533ef9907949cdef4d3564a63b3408
blob + 5884b06bf362b07fce7bf0404a289c16a7f88531
--- client.go
+++ client.go
package p9p
import (
+ "io"
"net"
"context"
return 0, ErrUnexpectedMsg
}
- return copy(p, rread.Data), nil
+ n = copy(p, rread.Data)
+ switch {
+ case len(rread.Data) == 0:
+ err = io.EOF
+ case n < len(p):
+ // TODO(stevvooe): Technically, we should treat this as an io.EOF.
+ // However, we cannot tell if the short read was due to EOF or due to
+ // truncation.
+ }
+
+ return n, err
}
func (c *client) Write(ctx context.Context, fid Fid, p []byte, offset int64) (n int, err error) {
return 0, ErrUnexpectedMsg
}
- return int(rwrite.Count), nil
+ if int(rwrite.Count) < len(p) {
+ err = io.ErrShortWrite
+ }
+
+ return int(rwrite.Count), err
}
func (c *client) Open(ctx context.Context, fid Fid, mode Flag) (Qid, uint32, error) {
blob - /dev/null
blob + fa6137bef1462d6f9d0c392e1a53bf455e63df02 (mode 644)
--- /dev/null
+++ channel_test.go
+package p9p
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "net"
+ "testing"
+ "time"
+)
+
+// TestTwriteOverflow checks that WriteFcall trims the data of an oversized
+// Twrite so that the framed message lands exactly on msize, for a range of
+// overflow amounts.
+func TestTwriteOverflow(t *testing.T) {
+	const (
+		msize = 512
+
+		// size[4] Twrite tag[2] fid[4] offset[8] count[4] data[count] | count = 0
+		overhead = 4 + 1 + 2 + 4 + 8 + 4
+	)
+
+	ctx := context.Background()
+	conn := &mockConn{}
+	ch := NewChannel(conn, msize)
+	codec := ch.(*channel).codec
+
+	cases := []struct {
+		name     string
+		overflow int // amount to overflow the message by.
+	}{
+		{name: "BoundedOverflow", overflow: msize / 2},
+		{name: "LargeOverflow", overflow: msize * 3},
+		{name: "HeaderOverflow", overflow: overhead},
+		{name: "HeaderOffsetOverflow", overflow: overhead - 1},
+		{name: "OverflowByOne", overflow: 1},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			fcall := overflowMessage(codec, msize, tc.overflow)
+			data := fcall.Message.(MessageTwrite).Data
+			t.Logf("overflow: %v, len(data): %v, expected overflow: %v", tc.overflow, len(data), overhead+len(data)-msize)
+
+			conn.buf.Reset()
+			if err := ch.WriteFcall(ctx, fcall); err != nil {
+				t.Fatal(err)
+			}
+
+			// The size prefix must announce exactly msize bytes...
+			var size uint32
+			header := bytes.NewReader(conn.buf.Bytes())
+			if err := binary.Read(header, binary.LittleEndian, &size); err != nil {
+				t.Fatal(err)
+			}
+			if size != msize {
+				t.Fatalf("should have truncated size header: %d != %d", size, msize)
+			}
+
+			// ...and exactly msize bytes must have hit the connection.
+			if written := conn.buf.Len(); written != msize {
+				t.Fatalf("should have truncated message: conn.buf.Len(%v) != msize(%v)", written, msize)
+			}
+		})
+	}
+}
+
+// TestWriteOverflowError ensures that we return an error in cases when there
+// will certainly be an overflow and it cannot be resolved by truncating the
+// data, and that nothing is written to the connection in that case.
+func TestWriteOverflowError(t *testing.T) {
+	const msize = 4
+
+	var (
+		ctx   = context.Background()
+		conn  = &mockConn{}
+		ch    = NewChannel(conn, msize)
+		fcall = newFcall(1, MessageTwrite{
+			Data: bytes.Repeat([]byte{'A'}, 4),
+		})
+		// 4 is the wire size prefix; spelled out so the test pins it.
+		messageSize = 4 + ch.(*channel).codec.Size(fcall)
+	)
+
+	err := ch.WriteFcall(ctx, fcall)
+	if err == nil {
+		t.Fatal("error expected when overflowing message")
+	}
+
+	if Overflow(err) != messageSize-msize {
+		t.Fatalf("overflow should reflect messageSize and msize, %d != %d", Overflow(err), messageSize-msize)
+	}
+
+	// An overflow error promises that no message was sent.
+	if conn.buf.Len() != 0 {
+		t.Fatalf("nothing should be written on overflow: %d bytes written", conn.buf.Len())
+	}
+}
+
+// TestReadFcallOverflow ensures that messages coming over a network connection
+// do not overflow the msize. Invalid messages will cause `ReadFcall` to return
+// an Overflow error reporting exactly how many bytes were over.
+func TestReadFcallOverflow(t *testing.T) {
+	const (
+		msize = 256
+	)
+
+	var (
+		ctx   = context.Background()
+		conn  = &mockConn{}
+		ch    = NewChannel(conn, msize)
+		codec = ch.(*channel).codec
+	)
+
+	for _, testcase := range []struct {
+		name     string
+		overflow int
+	}{
+		{
+			name:     "OverflowByOne",
+			overflow: 1,
+		},
+		{
+			name:     "HeaderOverflow",
+			overflow: overheadMessage(codec, MessageTwrite{}),
+		},
+		{
+			name:     "HeaderOffsetOverflow",
+			overflow: overheadMessage(codec, MessageTwrite{}) - 1,
+		},
+	} {
+		t.Run(testcase.name, func(t *testing.T) {
+			fcall := overflowMessage(codec, msize, testcase.overflow)
+
+			// Sanity check the fixture before exercising the channel, so a bug
+			// in the test helpers is reported as such rather than as a
+			// ReadFcall failure.
+			if got := ch.(*channel).msgmsize(fcall) - msize; got != testcase.overflow {
+				t.Fatalf("overflow calculation incorrect: %v != %v", testcase.overflow, got)
+			}
+
+			// prepare the raw message
+			p, err := codec.Marshal(fcall)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			// "send" the message into the buffer, starting from an empty
+			// connection so no bytes leak in from an earlier subtest.
+			// this message is crafted to overflow the read buffer.
+			conn.buf.Reset()
+			if err := sendmsg(&conn.buf, p); err != nil {
+				t.Fatal(err)
+			}
+
+			var incoming Fcall
+			err = ch.ReadFcall(ctx, &incoming)
+			if err == nil {
+				t.Fatal("expected error on fcall")
+			}
+
+			if Overflow(err) != testcase.overflow {
+				t.Fatalf("unexpected overflow on error: %v !=%v", Overflow(err), testcase.overflow)
+			}
+		})
+	}
+}
+
+// TestTreadRewrite ensures that a Tread whose response would overflow the
+// msize has its Count adjusted before it is sent.
+func TestTreadRewrite(t *testing.T) {
+	const (
+		msize         = 256
+		overflowMSize = msize + 1
+	)
+
+	var (
+		ctx   = context.Background()
+		conn  = &mockConn{}
+		ch    = NewChannel(conn, msize)
+		buf   = make([]byte, overflowMSize)
+		fcall = newFcall(1, MessageTread{
+			Count: overflowMSize,
+		})
+		// responseMSize is the on-wire size of an Rread carrying the full
+		// overflowMSize bytes an unmodified Tread would ask for.
+		responseMSize = ch.(*channel).msgmsize(newFcall(1, MessageRread{
+			Data: buf,
+		}))
+		// The Count must shrink by exactly the amount that response would
+		// exceed msize.
+		expected = overflowMSize - (uint32(responseMSize) - msize)
+	)
+
+	if err := ch.WriteFcall(ctx, fcall); err != nil {
+		t.Fatal(err)
+	}
+
+	// just read the message off the buffer
+	n, err := readmsg(&conn.buf, buf)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	*fcall = Fcall{}
+	if err := ch.(*channel).codec.Unmarshal(buf[:n], fcall); err != nil {
+		t.Fatal(err)
+	}
+
+	tread, ok := fcall.Message.(MessageTread)
+	if !ok {
+		t.Fatalf("unexpected message: %v", fcall)
+	}
+
+	if tread.Count != expected {
+		t.Fatalf("count not rewritten: %v != %v", tread.Count, expected)
+	}
+}
+
+// mockConn is an in-memory net.Conn for tests: Write appends to buf and Read
+// consumes from it. Deadlines are accepted and ignored. Any other net.Conn
+// method is served by the nil embedded interface and will panic if called.
+//
+// All methods use pointer receivers so the embedded bytes.Buffer is never
+// copied and every call observes the same buffer.
+type mockConn struct {
+	net.Conn
+	buf bytes.Buffer
+}
+
+var _ net.Conn = (*mockConn)(nil)
+
+func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }
+func (m *mockConn) SetReadDeadline(t time.Time) error  { return nil }
+
+func (m *mockConn) Write(p []byte) (int, error) {
+	return m.buf.Write(p)
+}
+
+func (m *mockConn) Read(p []byte) (int, error) {
+	return m.buf.Read(p)
+}
+
+func overheadMessage(codec Codec, msg Message) int {
+ return 4 + codec.Size(newFcall(1, msg))
+}
+
+// overflowMessage returns message that overflows the msize by overflow bytes,
+// returning the message size and the fcall.
+func overflowMessage(codec Codec, msize, overflow int) *Fcall {
+ var (
+ overhead = overheadMessage(codec, MessageTwrite{})
+ data = bytes.Repeat([]byte{'A'}, (msize-overhead)+overflow)
+ fcall = newFcall(1, MessageTwrite{
+ Data: data,
+ })
+ )
+
+ return fcall
+}
blob - 417aed22503a0e939ccfa2597bec7c69ca8b0ddf
blob + 776e851f7bd9ecbf7ed8508af1878e5dca60d486
--- server.go
+++ server.go
}
go func(ctx context.Context, req *Fcall) {
+ // TODO(stevvooe): Re-write incoming Treads so that handler
+ // can always respond with a message of the correct msize.
+
var resp *Fcall
msg, err := c.handler.Handle(ctx, req.Message)
if err != nil {
for {
select {
case resp := <-responses:
+ // TODO(stevvooe): Correctly protect against overflowing msize from
+ // handler. This can be done above, in the main message handler
+ // loop, by adjusting incoming Tread calls to have a Count that
+ // won't overflow the msize.
+
if err := c.ch.WriteFcall(c.ctx, resp); err != nil {
if err, ok := err.(net.Error); ok {
if err.Timeout() || err.Temporary() {
blob - /dev/null
blob + ab4b7755ffe5d70016fbd5d03677a5a3a922243b (mode 644)
--- /dev/null
+++ overflow.go
+package p9p
+
+import "fmt"
+
+// Overflow reports the number of bytes by which the message behind err
+// overflowed the msize. It returns 0 if err is nil or carries no overflow.
+//
+// err itself is checked first. If it is not an overflow error, the chain of
+// wrapped errors is followed through both Cause() and Unwrap() methods until an
+// overflow error is found or the chain ends.
+func Overflow(err error) int {
+	for err != nil {
+		switch e := err.(type) {
+		case overflow:
+			return e.Size()
+		case interface{ Cause() error }:
+			err = e.Cause()
+		case interface{ Unwrap() error }:
+			err = e.Unwrap()
+		default:
+			return 0
+		}
+	}
+
+	return 0
+}
+
+// overflow is a resolvable error type that can help callers negotiate
+// session msize. If this error is returned when writing, no message was sent;
+// when reading, the oversized incoming message was discarded.
+//
+// The return value of `Size()` represents the number of bytes that would have
+// been truncated if the message were sent. This IS NOT the optimal buffer size
+// for operations like read and write.
+//
+// In the case of `Twrite`, the caller can subtract Size() from the local
+// buffer size to get an optimally sized buffer, or the write can simply be
+// truncated to `len(buf) - err.Size()`.
+//
+// For the most part, no users of this package should see this error in
+// practice. If this escapes the Session interface, it is a bug.
+type overflow interface {
+	Size() int // number of bytes overflowed.
+}
+
+// overflowErr is the concrete overflow error produced by the channel.
+type overflowErr struct {
+	size int // number of bytes overflowed
+}
+
+// Error implements the error interface.
+func (o overflowErr) Error() string {
+	return fmt.Sprintf("message overflowed %d bytes", o.size)
+}
+
+// Size returns the number of bytes overflowed.
+func (o overflowErr) Size() int {
+	return o.size
+}
blob - 429cdab15ed44a658f612e6be6001cedc23dc7dc
blob + eaca8d7018d33da28e0b906606c01aa64d360bfd
--- session.go
+++ session.go
Clunk(ctx context.Context, fid Fid) error
Remove(ctx context.Context, fid Fid) error
Walk(ctx context.Context, fid Fid, newfid Fid, names ...string) ([]Qid, error)
+
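+ // Read follows the semantics of the io.ReaderAt.ReadAt method except it
+ // takes a context and a Fid. Note that an implementation may currently
+ // return n < len(p) with a nil error when the read was truncated to fit the
+ // msize, which io.ReaderAt itself does not permit.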
Read(ctx context.Context, fid Fid, p []byte, offset int64) (n int, err error)
+
+ // Write follows the semantics of io.WriterAt.WriteAt except takes a context and an Fid.
+ //
+ // If n == len(p), no error is returned.
+ // If n < len(p), io.ErrShortWrite will be returned.
Write(ctx context.Context, fid Fid, p []byte, offset int64) (n int, err error)
+
Open(ctx context.Context, fid Fid, mode Flag) (Qid, uint32, error)
Create(ctx context.Context, parent Fid, name string, perm uint32, mode Flag) (Qid, uint32, error)
Stat(ctx context.Context, fid Fid) (Dir, error)