Commit Diff


commit - 529e2b2efc36aab8a17e32ba2516fc89d4cbd43a
commit + c74282f87c41aad90bbd9edc99c3c1fe45d65953
blob - 82e719d61f2683fee7c710f0b6f02a3856a2679c
blob + 984958022d113b06920ca64f2defedba90925620
--- channel.go
+++ channel.go
@@ -2,15 +2,19 @@ package p9p
 
 import (
 	"bufio"
+	"context"
 	"encoding/binary"
-	"fmt"
 	"io"
 	"io/ioutil"
 	"log"
 	"net"
 	"time"
+)
 
-	"context"
+const (
+	// channelMessageHeaderSize is the overhead for sending the size of a
+	// message on the wire: the size[4] field that prefixes every message and
+	// is counted as part of the message's on-wire size.
+	channelMessageHeaderSize = 4
 )
 
 // Channel defines the operations necessary to implement a 9p message channel
@@ -114,6 +118,9 @@ func (ch *channel) SetMSize(msize int) {
 }
 
 // ReadFcall reads the next message from the channel into fcall.
+//
+// If the incoming message overflows the msize, Overflow(err) will return
+// nonzero with the number of bytes overflowed.
 func (ch *channel) ReadFcall(ctx context.Context, fcall *Fcall) error {
 	select {
 	case <-ctx.Done():
@@ -140,9 +147,7 @@ func (ch *channel) ReadFcall(ctx context.Context, fcal
 	}
 
 	if n > len(ch.rdbuf) {
-		// TODO(stevvooe): Make this error detectable and respond with error
-		// message.
-		return fmt.Errorf("message too large for buffer: %v > %v ", n, len(ch.rdbuf))
+		return overflowErr{size: n - len(ch.rdbuf)}
 	}
 
 	// clear out the fcall
@@ -151,9 +156,19 @@ func (ch *channel) ReadFcall(ctx context.Context, fcal
 		return err
 	}
 
+	if err := ch.maybeTruncate(fcall); err != nil {
+		return err
+	}
+
 	return nil
 }
 
+// WriteFcall writes the message to the connection.
+//
+// If a message destined for the wire will overflow MSize, an Overflow error
+// may be returned. For Twrite calls, the buffer will simply be truncated to
+// the optimal msize, with the caller detecting this condition with
+// Rwrite.Count.
 func (ch *channel) WriteFcall(ctx context.Context, fcall *Fcall) error {
 	select {
 	case <-ctx.Done():
@@ -172,6 +187,10 @@ func (ch *channel) WriteFcall(ctx context.Context, fca
 		log.Printf("transport: error setting read deadline on %v: %v", ch.conn.RemoteAddr(), err)
 	}
 
+	if err := ch.maybeTruncate(fcall); err != nil {
+		return err
+	}
+
 	p, err := ch.codec.Marshal(fcall)
 	if err != nil {
 		return err
@@ -184,6 +203,86 @@ func (ch *channel) WriteFcall(ctx context.Context, fca
 	return ch.bwr.Flush()
 }
 
+// maybeTruncate will truncate the message to fit into msize on the wire, if
+// possible, or modify the message to ensure the response won't overflow.
+//
+// If the message cannot be truncated, an error will be returned and the
+// message should not be sent.
+//
+// A nil return value means the message can be sent without exceeding msize,
+// possibly after fcall.Message has been rewritten in place:
+//
+//   - MessageTread: Count is lowered so that the Rread it solicits (an empty
+//     Rread plus Count bytes of data) fits within msize. If msize is too
+//     small for any reduction to help, the message is left untouched.
+//   - MessageTwrite: Data is truncated so the Twrite fits within msize; the
+//     number of bytes actually written comes back in Rwrite.Count. An
+//     overflow error is returned if Data alone cannot absorb the excess.
+//   - everything else: an overflow error is returned if the message does
+//     not fit.
+func (ch *channel) maybeTruncate(fcall *Fcall) error {
+	// for certain message types, just remove the extra bytes from the data portion.
+	switch msg := fcall.Message.(type) {
+	// TODO(stevvooe): There is one more problematic message type:
+	//
+	// Rread: while we can employ the same truncation fix as Twrite, we
+	// need to make it observable to upstream handlers.
+
+	case MessageTread:
+		// We can rewrite msg.Count so that a return message will be under
+		// msize.  This is more defensive than anything but will ensure that
+		// calls don't fail on sloppy servers.
+
+		// first, craft the shape of the response message
+		resp := newFcall(fcall.Tag, MessageRread{})
+
+		// Compute in int64 so that neither the no-overflow case nor a very
+		// large Count depends on uint32 wraparound.
+		respSize := int64(ch.msgmsize(resp)) + int64(msg.Count)
+		if respSize <= int64(ch.msize) {
+			return nil // the response already fits; nothing to rewrite
+		}
+
+		overflow := respSize - int64(ch.msize)
+		if int64(msg.Count) < overflow {
+			// Let the bad thing happen; msize too small to even support valid
+			// rewrite. This will result in a Terror from the server-side or
+			// just work.
+			return nil
+		}
+
+		msg.Count -= uint32(overflow)
+		fcall.Message = msg // since we have a local copy
+
+		return nil
+	case MessageTwrite:
+		// If we are going to overflow the msize, we need to truncate the write to
+		// appropriate size or throw an error in all other conditions.
+		size := ch.msgmsize(fcall)
+		if size <= ch.msize {
+			return nil
+		}
+
+		// overflow the msize, including the channel message size fields.
+		overflow := size - ch.msize
+
+		if len(msg.Data) < overflow {
+			// paranoid: if msg.Data is not big enough to handle the
+			// overflow, we should get an overflow error. MSize would have
+			// to be way too small to be realistic.
+			return overflowErr{size: overflow}
+		}
+
+		// The truncation is reflected in the return message (Rwrite) by
+		// the server, so we don't need a return value or error condition
+		// to communicate it.
+		msg.Data = msg.Data[:len(msg.Data)-overflow]
+		fcall.Message = msg // since we have a local copy
+
+		return nil
+	default:
+		size := ch.msgmsize(fcall)
+		if size > ch.msize {
+			// overflow the msize, including the channel message size fields.
+			return overflowErr{size: size - ch.msize}
+		}
+
+		return nil
+	}
+}
+
+// msgmsize reports how many bytes fcall occupies on the wire: the codec's
+// encoded size plus the channelMessageHeaderSize-byte size prefix. Comparing
+// the result against msize tells whether the message would overflow.
+func (ch *channel) msgmsize(fcall *Fcall) int {
+	encoded := ch.codec.Size(fcall)
+	return encoded + channelMessageHeaderSize
+}
+
 // readmsg reads a 9p message into p from rd, ensuring that all bytes are
 // consumed from the size header. If the size header indicates the message is
 // larger than p, the entire message will be discarded, leaving a truncated
blob - 0a06e0cce5533ef9907949cdef4d3564a63b3408
blob + 5884b06bf362b07fce7bf0404a289c16a7f88531
--- client.go
+++ client.go
@@ -1,6 +1,7 @@
 package p9p
 
 import (
+	"io"
 	"net"
 
 	"context"
@@ -149,7 +150,17 @@ func (c *client) Read(ctx context.Context, fid Fid, p 
 		return 0, ErrUnexpectedMsg
 	}
 
-	return copy(p, rread.Data), nil
+	n = copy(p, rread.Data)
+	switch {
+	case len(rread.Data) == 0:
+		err = io.EOF
+	case n < len(p):
+		// TODO(stevvooe): Technically, we should treat this as an io.EOF.
+		// However, we cannot tell if the short read was due to EOF or due to
+		// truncation.
+	}
+
+	return n, err
 }
 
 func (c *client) Write(ctx context.Context, fid Fid, p []byte, offset int64) (n int, err error) {
@@ -167,7 +178,11 @@ func (c *client) Write(ctx context.Context, fid Fid, p
 		return 0, ErrUnexpectedMsg
 	}
 
-	return int(rwrite.Count), nil
+	if int(rwrite.Count) < len(p) {
+		err = io.ErrShortWrite
+	}
+
+	return int(rwrite.Count), err
 }
 
 func (c *client) Open(ctx context.Context, fid Fid, mode Flag) (Qid, uint32, error) {
blob - /dev/null
blob + fa6137bef1462d6f9d0c392e1a53bf455e63df02 (mode 644)
--- /dev/null
+++ channel_test.go
@@ -0,0 +1,256 @@
+package p9p
+
+import (
+	"bytes"
+	"context"
+	"encoding/binary"
+	"net"
+	"testing"
+	"time"
+)
+
+// TestTwriteOverflow ensures that a Twrite message will have the data field
+// truncated if the msize would be exceeded.
+//
+// Each subtest builds a Twrite whose on-wire size exceeds msize by
+// testcase.overflow bytes, writes it through the channel into an in-memory
+// connection, and checks that both the size prefix and the number of bytes
+// written equal msize exactly.
+func TestTwriteOverflow(t *testing.T) {
+	const (
+		msize = 512
+
+		// size[4] Twrite tag[2] fid[4] offset[8] count[4] data[count] | count = 0
+		overhead = 4 + 1 + 2 + 4 + 8 + 4
+	)
+
+	// conn and ch are shared by every subtest below.
+	var (
+		ctx  = context.Background()
+		conn = &mockConn{}
+		ch   = NewChannel(conn, msize)
+	)
+
+	for _, testcase := range []struct {
+		name     string
+		overflow int // amount to overflow the message by.
+	}{
+		{
+			name:     "BoundedOverflow",
+			overflow: msize / 2,
+		},
+		{
+			name:     "LargeOverflow",
+			overflow: msize * 3,
+		},
+		{
+			name:     "HeaderOverflow",
+			overflow: overhead,
+		},
+		{
+			name:     "HeaderOffsetOverflow",
+			overflow: overhead - 1,
+		},
+		{
+			name:     "OverflowByOne",
+			overflow: 1,
+		},
+	} {
+
+		t.Run(testcase.name, func(t *testing.T) {
+			var (
+				fcall = overflowMessage(ch.(*channel).codec, msize, testcase.overflow)
+				data  = fcall.Message.(MessageTwrite).Data
+				size  uint32
+			)
+
+			t.Logf("overflow: %v, len(data): %v, expected overflow: %v", testcase.overflow, len(data), overhead+len(data)-msize)
+			// Start each subtest from an empty wire buffer.
+			conn.buf.Reset()
+			if err := ch.WriteFcall(ctx, fcall); err != nil {
+				t.Fatal(err)
+			}
+
+			// The first 4 bytes written are the little-endian size prefix.
+			if err := binary.Read(bytes.NewReader(conn.buf.Bytes()), binary.LittleEndian, &size); err != nil {
+				t.Fatal(err)
+			}
+
+			// The size prefix counts the whole message, itself included, so
+			// both it and the total bytes written must equal msize.
+			if size != msize {
+				t.Fatalf("should have truncated size header: %d != %d", size, msize)
+			}
+
+			if conn.buf.Len() != msize {
+				t.Fatalf("should have truncated message: conn.buf.Len(%v) != msize(%v)", conn.buf.Len(), msize)
+			}
+		})
+	}
+
+}
+
+// TestWriteOverflowError ensures that we return an error in cases when there
+// will certainly be an overflow and it cannot be resolved.
+//
+// With an msize of 4, not even the fixed Twrite header fits, so truncating
+// Data cannot help. WriteFcall must fail with an overflow error whose size is
+// the message's full on-wire size minus msize.
+func TestWriteOverflowError(t *testing.T) {
+	const msize = 4
+
+	var (
+		ctx   = context.Background()
+		conn  = &mockConn{}
+		ch    = NewChannel(conn, msize)
+		data  = bytes.Repeat([]byte{'A'}, 4)
+		fcall = newFcall(1, MessageTwrite{
+			Data: data,
+		})
+		// on-wire size, including the size prefix
+		messageSize = ch.(*channel).msgmsize(fcall)
+	)
+
+	err := ch.WriteFcall(ctx, fcall)
+	if err == nil {
+		t.Fatal("error expected when overflowing message")
+	}
+
+	if Overflow(err) != messageSize-msize {
+		t.Fatalf("overflow should reflect messageSize and msize, %d != %d", Overflow(err), messageSize-msize)
+	}
+}
+
+// TestReadFcallOverflow ensures that messages coming over a network
+// connection do not overflow the msize. Invalid messages will cause
+// `ReadFcall` to return an Overflow error.
+//
+// Each subtest marshals a Twrite whose on-wire size exceeds msize by
+// testcase.overflow bytes, frames it into the in-memory connection, and
+// checks that ReadFcall fails with Overflow(err) equal to that excess.
+func TestReadFcallOverflow(t *testing.T) {
+	const (
+		msize = 256
+	)
+
+	// NOTE(review): conn.buf is shared across subtests; this relies on
+	// ReadFcall consuming the entire oversized message on error so the next
+	// subtest starts at a message boundary.
+	var (
+		ctx   = context.Background()
+		conn  = &mockConn{}
+		ch    = NewChannel(conn, msize)
+		codec = ch.(*channel).codec
+	)
+
+	for _, testcase := range []struct {
+		name     string
+		overflow int
+	}{
+		{
+			name:     "OverflowByOne",
+			overflow: 1,
+		},
+		{
+			name:     "HeaderOverflow",
+			overflow: overheadMessage(codec, MessageTwrite{}),
+		},
+		{
+			name:     "HeaderOffsetOverflow",
+			overflow: overheadMessage(codec, MessageTwrite{}) - 1,
+		},
+	} {
+		t.Run(testcase.name, func(t *testing.T) {
+			fcall := overflowMessage(codec, msize, testcase.overflow)
+
+			// prepare the raw message
+			p, err := ch.(*channel).codec.Marshal(fcall)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			// "send" the message into the buffer
+			// this message is crafted to overflow the read buffer.
+			if err := sendmsg(&conn.buf, p); err != nil {
+				t.Fatal(err)
+			}
+
+			var incoming Fcall
+			err = ch.ReadFcall(ctx, &incoming)
+			if err == nil {
+				t.Fatal("expected error on fcall")
+			}
+
+			// sanity check to ensure our test code has the right overflow
+			if testcase.overflow != ch.(*channel).msgmsize(fcall)-msize {
+				t.Fatalf("overflow calculation incorrect: %v != %v", testcase.overflow, ch.(*channel).msgmsize(fcall)-msize)
+			}
+
+			if Overflow(err) != testcase.overflow {
+				t.Fatalf("unexpected overflow on error: %v !=%v", Overflow(err), testcase.overflow)
+			}
+		})
+	}
+}
+
+// TestTreadRewrite ensures that a Tread whose response would overflow the
+// msize has its Count adjusted before sending, so that the matching Rread
+// fits exactly within msize.
+func TestTreadRewrite(t *testing.T) {
+	const (
+		msize         = 256
+		overflowMSize = msize + 1
+	)
+
+	var (
+		ctx   = context.Background()
+		conn  = &mockConn{}
+		ch    = NewChannel(conn, msize)
+		buf   = make([]byte, overflowMSize)
+		fcall = newFcall(1, MessageTread{
+			Count: overflowMSize,
+		})
+		responseMSize = ch.(*channel).msgmsize(newFcall(1, MessageRread{
+			Data: buf,
+		}))
+		// Count must shrink by exactly the amount the full response overflows.
+		expected = overflowMSize - (uint32(responseMSize) - msize)
+	)
+
+	if err := ch.WriteFcall(ctx, fcall); err != nil {
+		t.Fatal(err)
+	}
+
+	// just read the message off the buffer
+	n, err := readmsg(&conn.buf, buf)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	*fcall = Fcall{}
+	if err := ch.(*channel).codec.Unmarshal(buf[:n], fcall); err != nil {
+		t.Fatal(err)
+	}
+
+	tread, ok := fcall.Message.(MessageTread)
+	if !ok {
+		t.Fatalf("unexpected message: %v", fcall)
+	}
+
+	if tread.Count != expected {
+		t.Fatalf("count not rewritten: %v != %v", tread.Count, expected)
+	}
+}
+
+// mockConn is an in-memory net.Conn for tests: writes are appended to buf and
+// reads are served from it, and deadlines are accepted and ignored. Any other
+// net.Conn method goes to the embedded net.Conn, which is nil unless set, and
+// will panic.
+//
+// All methods use pointer receivers so the bytes.Buffer is never copied.
+type mockConn struct {
+	net.Conn
+	buf bytes.Buffer
+}
+
+func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }
+func (m *mockConn) SetReadDeadline(t time.Time) error  { return nil }
+
+func (m *mockConn) Write(p []byte) (int, error) {
+	return m.buf.Write(p)
+}
+
+func (m *mockConn) Read(p []byte) (int, error) {
+	return m.buf.Read(p)
+}
+
+// overheadMessage returns the on-wire size of msg wrapped in an Fcall with
+// tag 1, including the size prefix. For a message with an empty payload (such
+// as MessageTwrite{}) this is the fixed per-message overhead.
+func overheadMessage(codec Codec, msg Message) int {
+	return channelMessageHeaderSize + codec.Size(newFcall(1, msg))
+}
+
+// overflowMessage returns message that overflows the msize by overflow bytes,
+// returning the message size and the fcall.
+func overflowMessage(codec Codec, msize, overflow int) *Fcall {
+	var (
+		overhead = overheadMessage(codec, MessageTwrite{})
+		data     = bytes.Repeat([]byte{'A'}, (msize-overhead)+overflow)
+		fcall    = newFcall(1, MessageTwrite{
+			Data: data,
+		})
+	)
+
+	return fcall
+}
blob - 417aed22503a0e939ccfa2597bec7c69ca8b0ddf
blob + 776e851f7bd9ecbf7ed8508af1878e5dca60d486
--- server.go
+++ server.go
@@ -132,6 +132,9 @@ func (c *conn) serve() error {
 				}
 
 				go func(ctx context.Context, req *Fcall) {
+					// TODO(stevvooe): Re-write incoming Treads so that handler
+					// can always respond with a message of the correct msize.
+
 					var resp *Fcall
 					msg, err := c.handler.Handle(ctx, req.Message)
 					if err != nil {
@@ -208,6 +211,11 @@ func (c *conn) write(responses chan *Fcall) {
 	for {
 		select {
 		case resp := <-responses:
+			// TODO(stevvooe): Correctly protect against overflowing msize from
+			// handler. This can be done above, in the main message handler
+			// loop, by adjusting incoming Tread calls to have a Count that
+			// won't overflow the msize.
+
 			if err := c.ch.WriteFcall(c.ctx, resp); err != nil {
 				if err, ok := err.(net.Error); ok {
 					if err.Timeout() || err.Temporary() {
blob - /dev/null
blob + ab4b7755ffe5d70016fbd5d03677a5a3a922243b (mode 644)
--- /dev/null
+++ overflow.go
@@ -0,0 +1,49 @@
+package p9p
+
+import "fmt"
+
+// Overflow will return a positive number, indicating there was an overflow for
+// the error.
+func Overflow(err error) int {
+	if of, ok := err.(overflow); ok {
+		return of.Size()
+	}
+
+	// traverse cause, if above fails.
+	if causal, ok := err.(interface {
+		Cause() error
+	}); ok {
+		return Overflow(causal.Cause())
+	}
+
+	return 0
+}
+
+// overflow is a resolvable error type that can help callers negotiate
+// session msize. If this error is returned from WriteFcall, no message was
+// sent.
+//
+// The return value of `Size()` represents the number of bytes that would have
+// been truncated if the message were sent. This IS NOT the optimal buffer size
+// for operations like read and write.
+//
+// In the case of `Twrite`, the caller can subtract Size() from the local
+// buffer size to get an optimally sized buffer, or the write can simply be
+// truncated to `len(buf) - err.Size()`.
+//
+// For the most part, no users of this package should see this error in
+// practice. If this escapes the Session interface, it is a bug.
+type overflow interface {
+	Size() int // number of bytes overflowed.
+}
+
+// overflowErr is the concrete overflow error produced by the channel; size
+// is the number of bytes by which a message exceeded msize.
+type overflowErr struct {
+	size int // number of bytes overflowed
+}
+
+// Error describes how many bytes the message overflowed by.
+func (e overflowErr) Error() string {
+	msg := fmt.Sprintf("message overflowed %d bytes", e.size)
+	return msg
+}
+
+// Size reports the number of bytes overflowed.
+func (e overflowErr) Size() int { return e.size }
blob - 429cdab15ed44a658f612e6be6001cedc23dc7dc
blob + eaca8d7018d33da28e0b906606c01aa64d360bfd
--- session.go
+++ session.go
@@ -19,8 +19,17 @@ type Session interface {
 	Clunk(ctx context.Context, fid Fid) error
 	Remove(ctx context.Context, fid Fid) error
 	Walk(ctx context.Context, fid Fid, newfid Fid, names ...string) ([]Qid, error)
+
+	// Read follows the semantics of the io.ReaderAt.ReadAt method, except it
+	// takes a context and a Fid.
 	Read(ctx context.Context, fid Fid, p []byte, offset int64) (n int, err error)
+
+	// Write follows the semantics of io.WriterAt.WriteAt, except it takes a context and a Fid.
+	//
+	// If n == len(p), no error is returned.
+	// If n < len(p), io.ErrShortWrite will be returned.
 	Write(ctx context.Context, fid Fid, p []byte, offset int64) (n int, err error)
+
 	Open(ctx context.Context, fid Fid, mode Flag) (Qid, uint32, error)
 	Create(ctx context.Context, parent Fid, name string, perm uint32, mode Flag) (Qid, uint32, error)
 	Stat(ctx context.Context, fid Fid) (Dir, error)