Commit Diff


commit - 8a7ec69711074b12b52a63f9eb61ce8cc82425bb
commit + 499f8c59005e11c0b5590adabe8b660c4a4cf1cb
blob - 1e7044f40483d25674fb40acc8b2a71b900925c7
blob + 5d526a8dcd50d0bf26806e5a1c3e3521a209b8b2
--- client.go
+++ client.go
@@ -14,14 +14,15 @@ import (
 type client struct {
 	conn     net.Conn
 	tags     *tagPool
-	requests chan *fcallRequest
+	requests chan fcallRequest
+	closed   chan struct{}
 }
 
 // NewSession returns a session using the connection.
 func NewSession(conn net.Conn) (Session, error) {
 	return &client{
 		conn: conn,
-	}
+	}, nil
 }
 
 var _ Session = &client{}
@@ -70,13 +71,12 @@ func (c *client) WStat(context.Context, Fid, Dir) erro
 	panic("not implemented")
 }
 
-func (c *client) Version(ctx context.Context, msize int32, version string) (int32, string, error) {
+func (c *client) Version(ctx context.Context, msize uint32, version string) (uint32, string, error) {
 	fcall := &Fcall{
-		Type: TVersion,
-		Tag:  tag,
+		Type: Tversion,
 		Message: MessageVersion{
-			MSize:   msize,
-			Version: Version,
+			MSize:   uint32(msize),
+			Version: version,
 		},
 	}
 
@@ -87,12 +87,18 @@ func (c *client) Version(ctx context.Context, msize in
 
 	mv, ok := resp.Message.(*MessageVersion)
 	if !ok {
-		return fmt.Errorf("invalid rpc response for version message: %v", resp)
+		return 0, "", fmt.Errorf("invalid rpc response for version message: %v", resp)
 	}
 
 	return mv.MSize, mv.Version, nil
 }
 
+func (c *client) flush(ctx context.Context, tag Tag) error {
+	// TODO(stevvooe): We need to fire and forget flush messages when a call
+	// context gets cancelled.
+	panic("not implemented")
+}
+
 // send dispatches the fcall.
 func (c *client) send(ctx context.Context, fc *Fcall) (*Fcall, error) {
 	fc.Tag = c.tags.Get()
@@ -111,7 +117,7 @@ func (c *client) send(ctx context.Context, fc *Fcall) 
 
 	// wait for the response.
 	select {
-	case <-closed:
+	case <-c.closed:
 		return nil, ErrClosed
 	case <-ctx.Done():
 		return nil, ctx.Err()
@@ -132,7 +138,7 @@ func newFcallRequest(ctx context.Context, fc *Fcall) f
 		ctx:      ctx,
 		fcall:    fc,
 		response: make(chan *Fcall, 1),
-		err:      make(chan err, 1),
+		err:      make(chan error, 1),
 	}
 }
 
@@ -147,7 +153,7 @@ func (c *client) handle() {
 
 	// loop to read messages off of the connection
 	go func() {
-		r := bufio.NewReader(c.conn)
+		dec := &decoder{bufio.NewReader(c.conn)}
 
 	loop:
 		for {
@@ -160,7 +166,7 @@ func (c *client) handle() {
 			}
 
 			fc := new(Fcall)
-			if err := read9p(r, fc); err != nil {
+			if err := dec.decode(fc); err != nil {
 				switch err := err.(type) {
 				case net.Error:
 					if err.Timeout() || err.Temporary() {
@@ -172,7 +178,7 @@ func (c *client) handle() {
 			}
 
 			select {
-			case <-closed:
+			case <-c.closed:
 				return
 			case responses <- fc:
 			}
@@ -180,14 +186,14 @@ func (c *client) handle() {
 
 	}()
 
-	w := bufio.NewWriter(c.conn)
+	enc := &encoder{bufio.NewWriter(c.conn)}
 
 	for {
 		select {
 		case <-c.closed:
 			return
 		case req := <-c.requests:
-			outstanding[req.fcall.Tag] = req
+			outstanding[req.fcall.Tag] = &req
 
 			// use deadline to set write deadline for this request.
 			deadline, ok := req.ctx.Deadline()
@@ -199,7 +205,7 @@ func (c *client) handle() {
 				log.Println("error setting write deadline: %v", err)
 			}
 
-			if err := write9p(w, req.fcall); err != nil {
+			if err := enc.encode(req.fcall); err != nil {
 				delete(outstanding, req.fcall.Tag)
 				req.err <- err
 			}
blob - /dev/null
blob + 081bacb99c65174f24ce8108cfa6eebc7219e27e (mode 644)
--- /dev/null
+++ encoding.go
@@ -0,0 +1,204 @@
+package p9pnew
+
+import (
+	"encoding/binary"
+	"fmt"
+	"io"
+	"reflect"
+)
+
+// NOTE(stevvooe): This file covers 9p encoding and decoding (despite just
+// being called encoding).
+
+type encoder struct {
+	wr io.Writer
+}
+
+// encode writes each of vs to e.wr in 9p wire format, stopping at the first error.
+func (e *encoder) encode(vs ...interface{}) error {
+	for _, v := range vs {
+		switch v := v.(type) {
+		case nil:
+			// nothing to write, e.g. the absent body of an Rflush; size9p skips nil too.
+		case *string:
+			if err := e.encode(*v); err != nil {
+				return err
+			}
+		case string:
+			// implement string[s] encoding
+			if err := binary.Write(e.wr, binary.LittleEndian, uint16(len(v))); err != nil {
+				return err
+			}
+			if _, err := io.WriteString(e.wr, v); err != nil {
+				return err
+			}
+		case Message:
+			// encode the message's fields in struct declaration order. We
+			// may add tag ignoring if needed.
+			elements, err := fields9p(v)
+			if err != nil {
+				return err
+			}
+
+			if err := e.encode(elements...); err != nil {
+				return err
+			}
+		case Fcall:
+			if err := e.encode(&v); err != nil {
+				return err
+			}
+		case *Fcall:
+			if err := e.encode(size9p(v), v.Type, v.Tag, v.Message); err != nil {
+				return err
+			}
+		default:
+			if err := binary.Write(e.wr, binary.LittleEndian, v); err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
+type decoder struct {
+	rd io.Reader
+}
+
+// decode reads values from d.rd and unmarshals them into the targets of vs.
+func (d *decoder) decode(vs ...interface{}) error {
+	for _, v := range vs {
+		switch v := v.(type) {
+		case *string:
+			var ll uint16
+
+			// implement string[s] encoding
+			if err := binary.Read(d.rd, binary.LittleEndian, &ll); err != nil {
+				return err
+			}
+
+			// ReadFull only returns a nil error once b is completely filled.
+			b := make([]byte, ll)
+			if _, err := io.ReadFull(d.rd, b); err != nil {
+				return err
+			}
+
+			*v = string(b)
+		case *Fcall:
+			var size uint32
+			if err := d.decode(&size, &v.Type, &v.Tag); err != nil {
+				return err
+			}
+
+			var err error
+			v.Message, err = newMessage(v.Type)
+			if err != nil {
+				return err
+			}
+
+			// body-less types such as Rflush yield a nil message; nothing more to read.
+			if v.Message == nil {
+				continue
+			}
+
+			if err := d.decode(v.Message); err != nil {
+				return err
+			}
+		case Message:
+			elements, err := fields9p(v)
+			if err != nil {
+				return err
+			}
+
+			if err := d.decode(elements...); err != nil {
+				return err
+			}
+		default:
+			if err := binary.Read(d.rd, binary.LittleEndian, v); err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
+// size9p calculates the projected size of the values in vs when encoded into
+// 9p binary protocol. If the fields of a message cannot be extracted, 0 is
+// returned for the whole call. The error will be detected when
+// encoding.
+func size9p(vs ...interface{}) uint32 {
+	var s uint32
+	for _, v := range vs {
+		if v == nil {
+			continue
+		}
+
+		switch v := v.(type) {
+		case *string:
+			s += uint32(binary.Size(uint16(0)) + len(*v))
+		case string:
+			s += uint32(binary.Size(uint16(0)) + len(v))
+		case Message:
+			// walk the fields of the message to get the total size. we just
+			// use the field order from the message struct. We may add tag
+			// ignoring if needed.
+			elements, err := fields9p(v)
+			if err != nil {
+				// BUG(stevvooe): The options here are to return 0, panic or
+				// make this return an error. Ideally, we make it safe to
+				// return 0 and have the rest of the package do the right
+				// thing. For now, we do this, but may want to panic until
+				// things are stable.
+				return 0
+			}
+
+			s += size9p(elements...)
+		case Fcall:
+			s += size9p(v.Type, v.Tag, v.Message)
+		case *Fcall:
+			// Calculates the total size of the fcall, excluding the size
+			// header, which is handled externally. The result is the
+			// encoded length of the type, tag and message body that
+			// follow the size field.
+			s += size9p(v.Type, v.Tag, v.Message)
+		default:
+			s += uint32(binary.Size(v))
+		}
+	}
+
+	return s
+}
+
+// fields9p lists the settable fields from a struct type for reading and
+// writing. We are using a lot of reflection here for fairly static
+// serialization but we can replace this in the future with generated code if
+// performance is an issue.
+func fields9p(v interface{}) ([]interface{}, error) {
+	rv := reflect.Indirect(reflect.ValueOf(v))
+
+	if rv.Kind() != reflect.Struct {
+		return nil, fmt.Errorf("cannot extract fields from non-struct: %v", rv)
+	}
+
+	var elements []interface{}
+	for i := 0; i < rv.NumField(); i++ {
+		f := rv.Field(i)
+
+		if !f.CanInterface() {
+			return nil, fmt.Errorf("can't interface: %v", f)
+		}
+
+		if !f.CanSet() {
+			return nil, fmt.Errorf("cannot set %v", f)
+		}
+
+		if f.CanAddr() {
+			f = f.Addr()
+		}
+
+		elements = append(elements, f.Interface())
+	}
+
+	return elements, nil
+}
blob - 73b68dd64117abec8a2328e35d03d23921bd9f42
blob + d550f4532b8a7fa29594fd215f549c129367981e
--- fcall.go
+++ fcall.go
@@ -1,14 +1,7 @@
 package p9pnew
 
-import (
-	"bytes"
-	"encoding/binary"
-	"fmt"
-	"io"
+import "fmt"
 
-	"encoding"
-)
-
 type FcallType uint8
 
 const (
@@ -43,46 +36,99 @@ const (
 	Tmax
 )
 
+// String returns the name of the fcall type, or "Tunknown" if it is not recognized.
+func (fct FcallType) String() string {
+	switch fct {
+	case Tversion:
+		return "Tversion"
+	case Rversion:
+		return "Rversion"
+	case Tauth:
+		return "Tauth"
+	case Rauth:
+		return "Rauth"
+	case Tattach:
+		return "Tattach"
+	case Rattach:
+		return "Rattach"
+	case Terror:
+		return "Terror" // invalid.
+	case Rerror:
+		return "Rerror"
+	case Tflush:
+		return "Tflush"
+	case Rflush:
+		return "Rflush"
+	case Twalk:
+		return "Twalk"
+	case Rwalk:
+		return "Rwalk"
+	case Topen:
+		return "Topen"
+	case Ropen:
+		return "Ropen"
+	case Tcreate:
+		return "Tcreate"
+	case Rcreate:
+		return "Rcreate"
+	case Tread:
+		return "Tread"
+	case Rread:
+		return "Rread"
+	case Twrite:
+		return "Twrite"
+	case Rwrite:
+		return "Rwrite"
+	case Tclunk:
+		return "Tclunk"
+	case Rclunk:
+		return "Rclunk"
+	case Tremove:
+		return "Tremove"
+	case Rremove:
+		return "Rremove"
+	case Tstat:
+		return "Tstat"
+	case Rstat:
+		return "Rstat"
+	case Twstat:
+		return "Twstat"
+	case Rwstat:
+		return "Rwstat"
+	default:
+		return "Tunknown"
+	}
+}
+
 type Fcall struct {
-	Type    Type
+	Type    FcallType
 	Tag     Tag
 	Message Message
 }
 
-const (
-	fcallHeaderSize = 4 /*size*/ + 1 /*type*/
-)
-
-func (fc *Fcall) Size() int {
-	return fcallHeaderSize + fc.Message.Size()
+func (fc Fcall) String() string {
+	return fmt.Sprintf("%8d %v(%v) %v", size9p(fc), fc.Type, fc.Tag, fc.Message)
 }
 
-func (fc *Fcall) MarshalBinary() ([]byte, error) {
-	mp, err := fc.Message.MarshalBinary()
-	if err != nil {
-		return nil, err
-	}
+type Message interface {
+	// Size() uint32
 
-	b := bytes.NewBuffer(make([]byte, 0, fc.Size()))
-	if err := write9p(b, fc.Size(), fc.Tag, mp); err != nil {
-		return nil, err
-	}
+	// NOTE(stevvooe): The binary marshal approach isn't particularly nice to
+	// generating garbage. Consider using an append model, once we have the
+	// messages worked out.
+	// encoding.BinaryMarshaler
+	// encoding.BinaryUnmarshaler
 
-	return b.Bytes(), nil
+	message9p()
 }
 
-func (fc *Fcall) UnmarshalBinary(p []data) error {
-	var (
-		r = bytes.NewReader(p)
-	)
-
-	if err := read9p(r, &fc.Type, &fc.Tag); err != nil {
-		return err
-	}
-
-	switch fc.Type {
+// newMessage returns a new instance of the message based on the Fcall type.
+func newMessage(typ FcallType) (Message, error) {
+	// NOTE(stevvooe): This is a nasty bit of code but makes the transport
+	// fairly simple to implement.
+	switch typ {
 	case Tversion, Rversion:
-		fc.Message = &MessageVersion{}
+		return &MessageVersion{}, nil
 	case Tauth:
 
 	case Rauth:
@@ -96,9 +142,9 @@ func (fc *Fcall) UnmarshalBinary(p []data) error {
 	case Rerror:
 
 	case Tflush:
-
+		return &MessageFlush{}, nil
 	case Rflush:
-
+		return nil, nil // No message body for this response.
 	case Twalk:
 
 	case Rwalk:
@@ -134,18 +180,14 @@ func (fc *Fcall) UnmarshalBinary(p []data) error {
 	case Twstat:
 
 	case Rwstat:
+	default:
+		return nil, fmt.Errorf("unknown message type: %v", typ)
 
 	}
 
-	return fc.Message.UnmarshalBinary(p[len(p)-r.Len():])
+	return nil, fmt.Errorf("unknown message")
 }
 
-type Message interface {
-	Size() int
-	encoding.BinaryMarshaler
-	encoding.BinaryUnmarshaler
-}
-
 // MessageVersion encodes the message body for Tversion and Rversion RPC
 // calls. The body is identical in both directions.
 type MessageVersion struct {
@@ -153,109 +195,14 @@ type MessageVersion struct {
 	Version string
 }
 
-func (mv MessageVersion) Size() int {
-	return 4 + 2 + len(mv.Version)
+func (MessageVersion) message9p() {}
+func (mv MessageVersion) String() string {
+	return fmt.Sprintf("msize=%v version=%v", mv.MSize, mv.Version)
 }
 
-func (mv MessageVersion) MarshalBinary() ([]byte, error) {
-	b := bytes.NewBuffer(make([]byte, 0, mv.Size()))
-
-	if err := write9p(b, mv.MSize, mv.Version); err != nil {
-		return nil, err
-	}
-
-	return b.Bytes(), nil
+// MessageFlush handles the content for the Tflush message type.
+type MessageFlush struct {
+	Oldtag Tag
 }
 
-// write9p implements serialization for base types.
-func write9p(w io.Writer, vs ...interface{}) error {
-	for _, v := range vs {
-		switch v := v.(type) {
-		case string:
-			// implement string[s] encoding
-			if err := binary.Write(w, binary.LittleEndian, uint16(len(v))); err != nil {
-				return err
-			}
-
-			_, err := io.WriteString(w, s)
-			if err != nil {
-
-				return err
-			}
-		case *Fcall:
-			if err := write9p(w, v.Size()); err != nil {
-				return err
-			}
-			p, err := v.MarshalBinary()
-			if err != nil {
-				return err
-			}
-
-			n, err := w.Write(p)
-			if err != nil {
-				return err
-			}
-
-			if n != len(p) {
-				return io.ErrShortWrite
-			}
-
-			return nil
-		default:
-			if err := binary.Write(w, binary.LittleEndian, v); err != nil {
-				return err
-			}
-		}
-	}
-
-	return nil
-}
-
-// read9p extracts values from rd and unmarshals them to the targets of vs.
-func read9p(rd io.Reader, vs ...interface{}) error {
-	for _, v := range vs {
-		switch v := v.(type) {
-		case *string:
-			var ll uint16
-
-			// implement string[s] encoding
-			if err := binary.Read(r, binary.LittleEndian, &ll); err != nil {
-				return err
-			}
-
-			b := make([]byte, ll)
-
-			n, err := io.ReadFull(b)
-			if err != nil {
-				return err
-			}
-
-			if n != int(ll) {
-				return fmt.Errorf("unexpected string length")
-			}
-
-			*v = string(b)
-		case *Fcall:
-			var size uint32
-			if err := read9p(buffered, &size); err != nil {
-				return err
-			}
-
-			p := make([]byte, size)
-			n, err := io.ReadFull(p)
-			if err != nil {
-				return err
-			}
-
-			if n != size {
-				return fmt.Errorf("error reading fcall: short read")
-			}
-
-			return v.UnmarshalBinary(p)
-		default:
-			if err := binary.Read(r, binary.LittleEndian, v); err != nil {
-				return err
-			}
-		}
-	}
-}
+func (MessageFlush) message9p() {}
blob - /dev/null
blob + 074f520e9f040f5f136f3dec00b7db9a713be4f1 (mode 644)
--- /dev/null
+++ encoding_test.go
@@ -0,0 +1,72 @@
+package p9pnew
+
+import (
+	"bytes"
+	"reflect"
+	"testing"
+)
+
+func TestEncodeDecode(t *testing.T) {
+	for _, testcase := range []struct {
+		description string
+		target      interface{}
+		marshaled   []byte
+	}{
+		{
+			description: "string",
+			target:      "asdf",
+			marshaled:   []byte{0x4, 0x0, 0x61, 0x73, 0x64, 0x66},
+		},
+		{
+			description: "Tversion fcall",
+			target: &Fcall{
+				Type: Tversion,
+				Tag:  2255,
+				Message: &MessageVersion{
+					MSize:   uint32(1024),
+					Version: "9PTEST",
+				},
+			},
+			marshaled: []byte{0xf, 0x0, 0x0, 0x0, 0x64, 0xcf, 0x8, 0x0, 0x4, 0x0, 0x0, 0x6, 0x0, 0x39, 0x50, 0x54, 0x45, 0x53, 0x54},
+		},
+	} {
+		t.Logf("target under test: %v", testcase.target)
+		fatalf := func(format string, args ...interface{}) {
+			t.Fatalf(testcase.description+": "+format, args...)
+		}
+
+		var b bytes.Buffer
+
+		enc := &encoder{&b}
+		dec := &decoder{&b}
+
+		if err := enc.encode(testcase.target); err != nil {
+			fatalf("error writing fcall: %v", err)
+		}
+
+		if !bytes.Equal(b.Bytes(), testcase.marshaled) {
+			fatalf("unexpected bytes for fcall: %#v != %#v", b.Bytes(), testcase.marshaled)
+		}
+
+		var v interface{}
+		targetType := reflect.TypeOf(testcase.target)
+
+		if targetType.Kind() == reflect.Ptr {
+			v = reflect.New(targetType.Elem()).Interface()
+		} else {
+			v = reflect.New(targetType).Interface()
+		}
+
+		if err := dec.decode(v); err != nil {
+			fatalf("error reading fcall: %v", err)
+		}
+
+		if targetType.Kind() != reflect.Ptr {
+			v = reflect.Indirect(reflect.ValueOf(v)).Interface()
+		}
+
+		if !reflect.DeepEqual(v, testcase.target) {
+			fatalf("not equal: %#v != %#v", v, testcase.target)
+		}
+	}
+}
blob - df87804033610f46a568197526e072fb5292a7c9
blob + f1bae724b08d7812eafd714ad375fcaa413085ed
--- logging.go
+++ logging.go
@@ -1,3 +1,5 @@
+// +build ignore
+
 package p9pnew
 
 import (
blob - 287a302388f8565d8ce9c6fcb628fa154a9d1a7a
blob + 032f07f44d7148b2ac5d6755d3afae997b5606f5
--- session.go
+++ session.go
@@ -33,7 +33,7 @@ type Session interface {
 	// TODO(stevvooe): The version message affects a lot of protocol behavior.
 	// Consider hiding it behind the implementation, letting the version get
 	// negotiated. The API user should still be able to query it.
-	Version(ctx context.Context, msize int32, version string) (int32, string, error)
+	Version(ctx context.Context, msize uint32, version string) (uint32, string, error)
 }
 
 func Dial(addr string) (Session, error) {
blob - 23ab0c4ea3426ea68d3967bfc611167ad034ef77
blob + 524e7b4ec21e4fcfb010d59e7ba62e07fb3a78f1
--- types.go
+++ types.go
@@ -1,11 +1,7 @@
 package p9pnew
 
-import (
-	"encoding"
+import "time"
 
-	"time"
-)
-
 const (
 	NOFID = ^Fid(0)
 	NOTAG = ^Tag(0)
@@ -75,60 +71,3 @@ type Dir struct {
 
 //
 type Tag uint16
-
-type FcallType uint8
-
-const (
-	FcallTypeTversion FcallType = iota + 100
-	FcallTypeRversion
-	FcallTypeTauth
-	FcallTypeRauth
-	FcallTypeTattach
-	FcallTypeRattach
-	FcallTypeTerror
-	FcallTypeRerror
-	FcallTypeTflush
-	FcallTypeRflush
-	FcallTypeTwalk
-	FcallTypeRwalk
-	FcallTypeTopen
-	FcallTypeRopen
-	FcallTypeTcreate
-	FcallTypeRcreate
-	FcallTypeTread
-	FcallTypeRread
-	FcallTypeTwrite
-	FcallTypeRwrite
-	FcallTypeTclunk
-	FcallTypeRclunk
-	FcallTypeTremove
-	FcallTypeRremove
-	FcallTypeTstat
-	FcallTypeRstat
-	FcallTypeTwstat
-	FcallTypeRwstat
-	FcallTypeTmax
-)
-
-type Fcall struct {
-	Type    Type
-	Fid     Fid
-	Tag     Tag
-	Message Message
-}
-
-type Message interface {
-	Size() int
-	encoding.BinaryMarshaler
-	encoding.BinaryUnmarshaler
-}
-
-type MessageVersion struct {
-	MSize   uint32
-	Version string
-}
-
-func (mv MessageVersion) MarshalBinary() ([]byte, error) {
-
-	encoding.BinaryMarshaler
-}
blob - /dev/null
blob + 1074f83dc94bc6e9447cf920226f37465cd3d61c (mode 644)
--- /dev/null
+++ server.go
@@ -0,0 +1,8 @@
+package p9pnew
+
+import "net"
+
+// Serve the 9p session over the provided network connection.
+func Serve(conn net.Conn, session Session) error {
+	panic("not implemented")
+}