Commit Diff


commit - dce371f63bd86f69e7164ea19cc60540efce050f
commit + 269e4d4b21f59c17f781c377325982e8cd369857
blob - c74ca513e2ac3883f696294c0d4f0852bcc092f6
blob + 806f4d297f3ace65b8671cfc7eee19c46b83db41
--- channel.go
+++ channel.go
@@ -77,7 +77,6 @@ type channel struct {
 	closed chan struct{}
 	msize  int
 	rdbuf  []byte
-	wrbuf  []byte
 }
 
 func newChannel(conn net.Conn, codec Codec, msize int) *channel {
@@ -89,7 +88,6 @@ func newChannel(conn net.Conn, codec Codec, msize int)
 		closed: make(chan struct{}),
 		msize:  msize,
 		rdbuf:  make([]byte, msize),
-		wrbuf:  make([]byte, msize),
 	}
 }
 
@@ -107,12 +105,10 @@ func (ch *channel) SetMSize(msize int) {
 	if msize < len(ch.rdbuf) {
 		// just change the cap
 		ch.rdbuf = ch.rdbuf[:msize]
-		ch.wrbuf = ch.wrbuf[:msize]
 		return
 	}
 
 	ch.rdbuf = make([]byte, msize)
-	ch.wrbuf = make([]byte, msize)
 }
 
 // ReadFcall reads the next message from the channel into fcall.
@@ -175,13 +171,11 @@ func (ch *channel) WriteFcall(ctx context.Context, fca
 		log.Printf("transport: error setting read deadline on %v: %v", ch.conn.RemoteAddr(), err)
 	}
 
-	n, err := ch.codec.Marshal(ch.wrbuf, fcall)
+	p, err := ch.codec.Marshal(fcall)
 	if err != nil {
 		return err
 	}
 
-	p := ch.wrbuf[:n]
-
 	if err := sendmsg(ch.bwr, p); err != nil {
 		return err
 	}
blob - 2946cb9286fea0999eccc58cd13bb1147f8ca13d
blob + bfce22d1890ddf9139f07f5af8b9f279299a429c
--- encoding.go
+++ encoding.go
@@ -11,26 +11,20 @@ import (
 	"time"
 )
 
-// EncodeDir is just a helper for encoding directories until we export the
-// encoder and decoder.
-func EncodeDir(wr io.Writer, d *Dir) error {
-	enc := &encoder{wr}
+// Codec defines the interface for encoding and decoding of 9p types.
+// Unsupported types will throw an error.
+type Codec interface {
+	// Unmarshal from data into the value pointed to by v.
+	Unmarshal(data []byte, v interface{}) error
 
-	return enc.encode(d)
+	// Marshal the value v into a byte slice.
+	Marshal(v interface{}) ([]byte, error)
 }
 
-// DecodeDir is just a helper for decoding directories until we export the
-// encoder and decoder.
-func DecodeDir(rd io.Reader, d *Dir) error {
-	dec := &decoder{rd}
-	return dec.decode(d)
+func NewCodec() Codec {
+	return codec9p{}
 }
 
-type Codec interface {
-	Unmarshal(data []byte, v interface{}) error
-	Marshal(data []byte, v interface{}) (n int, err error)
-}
-
 type codec9p struct{}
 
 func (c codec9p) Unmarshal(data []byte, v interface{}) error {
@@ -38,26 +32,17 @@ func (c codec9p) Unmarshal(data []byte, v interface{})
 	return dec.decode(v)
 }
 
-func (c codec9p) Marshal(data []byte, v interface{}) (n int, err error) {
-	n = int(size9p(v))
+func (c codec9p) Marshal(v interface{}) ([]byte, error) {
+	var b bytes.Buffer
+	enc := &encoder{&b}
 
-	buf := bytes.NewBuffer(data[:0])
-	enc := &encoder{buf}
-
 	if err := enc.encode(v); err != nil {
-		return buf.Len(), nil
+		return nil, err
 	}
 
-	if len(data) < buf.Len() {
-		return len(data), io.ErrShortBuffer
-	}
-
-	return buf.Len(), nil
+	return b.Bytes(), nil
 }
 
-// NOTE(stevvooe): This file covers 9p encoding and decoding (despite just
-// being called encoding).
-
 type encoder struct {
 	wr io.Writer
 }
@@ -65,37 +50,25 @@ type encoder struct {
 func (e *encoder) encode(vs ...interface{}) error {
 	for _, v := range vs {
 		switch v := v.(type) {
-		case *[]string:
-			if err := e.encode(*v); err != nil {
+		case uint8, uint16, uint32, uint64, FcallType, Tag, QType, Fid,
+			*uint8, *uint16, *uint32, *uint64, *FcallType, *Tag, *QType, *Fid:
+			if err := binary.Write(e.wr, binary.LittleEndian, v); err != nil {
 				return err
 			}
-		case []string:
-			if err := e.encode(uint16(len(v))); err != nil {
+		case []byte:
+			if err := e.encode(uint32(len(v))); err != nil {
 				return err
 			}
 
-			var elements []interface{}
-			for _, e := range v {
-				elements = append(elements, e)
-			}
-
-			if err := e.encode(elements...); err != nil {
+			if err := binary.Write(e.wr, binary.LittleEndian, v); err != nil {
 				return err
 			}
-		case *[]byte:
-			if err := e.encode(uint32(len(*v))); err != nil {
-				return err
-			}
 
+		case *[]byte:
 			if err := e.encode(*v); err != nil {
 				return err
 			}
-		case *string:
-			if err := e.encode(*v); err != nil {
-				return err
-			}
 		case string:
-			// implement string[s] encoding
 			if err := binary.Write(e.wr, binary.LittleEndian, uint16(len(v))); err != nil {
 				return err
 			}
@@ -104,37 +77,30 @@ func (e *encoder) encode(vs ...interface{}) error {
 			if err != nil {
 				return err
 			}
-		case *Dir:
-			// NOTE(stevvooe): See bugs in http://man.cat-v.org/plan_9/5/stat
-			// to make sense of this. The field has been included here but we
-			// need to make sure to double emit it for Rstat.
-
-			elements, err := fields9p(v)
-			if err != nil {
+		case *string:
+			if err := e.encode(*v); err != nil {
 				return err
 			}
 
-			elements = append([]interface{}{uint16(size9p(elements...))}, elements...)
-
-			if err := e.encode(elements...); err != nil {
+		case []string:
+			if err := e.encode(uint16(len(v))); err != nil {
 				return err
 			}
-		case Message:
-			elements, err := fields9p(v)
-			if err != nil {
-				return err
-			}
 
-			switch v.(type) {
-			case MessageRstat, *MessageRstat:
-				// encode an size header in front of the dir field
-				elements = append([]interface{}{uint16(size9p(elements...))}, elements...)
+			for _, m := range v {
+				if err := e.encode(m); err != nil {
+					return err
+				}
 			}
-
-			if err := e.encode(elements...); err != nil {
+		case *[]string:
+			if err := e.encode(*v); err != nil {
 				return err
 			}
-		case *Qid:
+		case time.Time:
+			if err := e.encode(uint32(v.Unix())); err != nil {
+				return err
+			}
+		case *time.Time:
 			if err := e.encode(*v); err != nil {
 				return err
 			}
@@ -142,7 +108,7 @@ func (e *encoder) encode(vs ...interface{}) error {
 			if err := e.encode(v.Type, v.Version, v.Path); err != nil {
 				return err
 			}
-		case *[]Qid:
+		case *Qid:
 			if err := e.encode(*v); err != nil {
 				return err
 			}
@@ -159,24 +125,53 @@ func (e *encoder) encode(vs ...interface{}) error {
 			if err := e.encode(elements...); err != nil {
 				return err
 			}
-		case time.Time:
-			if err := e.encode(uint32(v.Unix())); err != nil {
+		case *[]Qid:
+			if err := e.encode(*v); err != nil {
 				return err
 			}
-		case *time.Time:
+		case Dir:
+			elements, err := fields9p(v)
+			if err != nil {
+				return err
+			}
+
+			if err := e.encode(uint16(size9p(elements...))); err != nil {
+				return err
+			}
+
+			if err := e.encode(elements...); err != nil {
+				return err
+			}
+		case *Dir:
 			if err := e.encode(*v); err != nil {
 				return err
 			}
 		case Fcall:
-			if err := e.encode(&v); err != nil {
+			if err := e.encode(v.Type, v.Tag, v.Message); err != nil {
 				return err
 			}
 		case *Fcall:
-			if err := e.encode(v.Type, v.Tag, v.Message); err != nil {
+			if err := e.encode(*v); err != nil {
 				return err
 			}
-		default:
-			if err := binary.Write(e.wr, binary.LittleEndian, v); err != nil {
+		case Message:
+			elements, err := fields9p(v)
+			if err != nil {
+				return err
+			}
+
+			switch v.(type) {
+			case MessageRstat, *MessageRstat:
+				// NOTE(stevvooe): Prepend an extra size preceding Dir. See
+				// BUGS in http://man.cat-v.org/plan_9/5/stat to make sense of
+				// this: Dir already encodes its own size, so for Rstat the
+				// size must be emitted twice.
+				if err := e.encode(uint16(size9p(elements...))); err != nil {
+					return err
+				}
+			}
+
+			if err := e.encode(elements...); err != nil {
 				return err
 			}
 		}
@@ -193,6 +188,22 @@ type decoder struct {
 func (d *decoder) decode(vs ...interface{}) error {
 	for _, v := range vs {
 		switch v := v.(type) {
+		case *uint8, *uint16, *uint32, *uint64, *FcallType, *Tag, *QType, *Fid:
+			if err := binary.Read(d.rd, binary.LittleEndian, v); err != nil {
+				return err
+			}
+		case *[]byte:
+			var ll uint32
+
+			if err := d.decode(&ll); err != nil {
+				return err
+			}
+
+			*v = make([]byte, int(ll))
+
+			if err := binary.Read(d.rd, binary.LittleEndian, v); err != nil {
+				return err
+			}
 		case *string:
 			var ll uint16
 
@@ -205,7 +216,6 @@ func (d *decoder) decode(vs ...interface{}) error {
 
 			n, err := io.ReadFull(d.rd, b)
 			if err != nil {
-				log.Println("readfull failed:", err, ll, n)
 				return err
 			}
 
@@ -230,36 +240,17 @@ func (d *decoder) decode(vs ...interface{}) error {
 			if err := d.decode(elements...); err != nil {
 				return err
 			}
-		case *[]byte:
-			var ll uint32
-
-			if err := d.decode(&ll); err != nil {
+		case *time.Time:
+			var epoch uint32
+			if err := d.decode(&epoch); err != nil {
 				return err
 			}
 
-			*v = make([]byte, int(ll))
-
-			if err := binary.Read(d.rd, binary.LittleEndian, v); err != nil {
+			*v = time.Unix(int64(epoch), 0).UTC()
+		case *Qid:
+			if err := d.decode(&v.Type, &v.Version, &v.Path); err != nil {
 				return err
 			}
-		case *Fcall:
-			if err := d.decode(&v.Type, &v.Tag); err != nil {
-				return err
-			}
-
-			var err error
-			v.Message, err = newMessage(v.Type)
-			if err != nil {
-				log.Printf("unknown message type %#v", v.Type)
-				return err
-			}
-
-			// take the address of v.Message from the struct and encode into
-			// that.
-
-			if err := d.decode(v.Message); err != nil {
-				return err
-			}
 		case *[]Qid:
 			var ll uint16
 
@@ -276,13 +267,6 @@ func (d *decoder) decode(vs ...interface{}) error {
 			if err := d.decode(elements...); err != nil {
 				return err
 			}
-		case *time.Time:
-			var epoch uint32
-			if err := d.decode(&epoch); err != nil {
-				return err
-			}
-
-			*v = time.Unix(int64(epoch), 0).UTC()
 		case *Dir:
 			var ll uint16
 
@@ -306,19 +290,41 @@ func (d *decoder) decode(vs ...interface{}) error {
 			dec := &decoder{bytes.NewReader(b)}
 
 			if err := dec.decode(elements...); err != nil {
+				return err
+			}
+		case *Fcall:
+			if err := d.decode(&v.Type, &v.Tag); err != nil {
+				return err
+			}
+
+			message, err := newMessage(v.Type)
+			if err != nil {
+				return err
+			}
+
+			// NOTE(stevvooe): We do a little pointer dance to allocate the
+			// new type, write to it, then assign it back to the interface as
+			// a concrete type, avoiding a pointer (the interface) to a
+			// pointer.
+			rv := reflect.New(reflect.TypeOf(message))
+			if err := d.decode(rv.Interface()); err != nil {
 				return err
 			}
+
+			v.Message = rv.Elem().Interface()
 		case Message:
 			elements, err := fields9p(v)
 			if err != nil {
 				return err
 			}
 
-			// special case twstat and rstat for size fields. See bugs in
-			// http://man.cat-v.org/plan_9/5/stat to make sense of this.
 			switch v.(type) {
 			case *MessageRstat, MessageRstat:
-				// decode extra size header for stat structure.
+				// NOTE(stevvooe): Consume the extra size preceding Dir. See
+				// BUGS in http://man.cat-v.org/plan_9/5/stat to make sense of
+				// this: Rstat carries the stat size twice and the *Dir case
+				// decodes its own copy, so the outer size is read here before
+				// the stat structure is decoded.
 				var ll uint16
 				if err := d.decode(&ll); err != nil {
 					return err
@@ -328,14 +334,6 @@ func (d *decoder) decode(vs ...interface{}) error {
 			if err := d.decode(elements...); err != nil {
 				return err
 			}
-		case *Qid:
-			if err := d.decode(&v.Type, &v.Version, &v.Path); err != nil {
-				return err
-			}
-		default:
-			if err := binary.Read(d.rd, binary.LittleEndian, v); err != nil {
-				return err
-			}
 		}
 	}
 
@@ -354,24 +352,32 @@ func size9p(vs ...interface{}) uint32 {
 		}
 
 		switch v := v.(type) {
-		case *string:
-			s += uint32(binary.Size(uint16(0)) + len(*v))
+		case uint8, uint16, uint32, uint64, FcallType, Tag, QType, Fid,
+			*uint8, *uint16, *uint32, *uint64, *FcallType, *Tag, *QType, *Fid:
+			s += uint32(binary.Size(v))
+		case []byte:
+			s += uint32(binary.Size(uint32(0)) + len(v))
+		case *[]byte:
+			s += size9p(uint32(0), *v)
 		case string:
 			s += uint32(binary.Size(uint16(0)) + len(v))
-		case *[]string:
+		case *string:
 			s += size9p(*v)
 		case []string:
 			s += size9p(uint16(0))
-			elements := make([]interface{}, len(v))
-			for i := range elements {
-				elements[i] = v[i]
-			}
 
-			s += size9p(elements...)
-		case *[]byte:
-			s += size9p(uint32(0), *v)
-		case *[]Qid:
+			for _, sv := range v {
+				s += size9p(sv)
+			}
+		case *[]string:
 			s += size9p(*v)
+		case time.Time, *time.Time:
+			// BUG(stevvooe): Y2038 is coming.
+			s += size9p(uint32(0))
+		case Qid:
+			s += size9p(v.Type, v.Version, v.Path)
+		case *Qid:
+			s += size9p(*v)
 		case []Qid:
 			s += size9p(uint16(0))
 			elements := make([]interface{}, len(v))
@@ -379,10 +385,9 @@ func size9p(vs ...interface{}) uint32 {
 				elements[i] = &v[i]
 			}
 			s += size9p(elements...)
-		case time.Time, *time.Time:
-			s += size9p(uint32(0))
-		case Qid:
-			s += size9p(v.Type, v.Version, v.Path)
+		case *[]Qid:
+			s += size9p(*v)
+
 		case Dir:
 			// walk the fields of the message to get the total size. we just
 			// use the field order from the message struct. We may add tag
@@ -398,8 +403,13 @@ func size9p(vs ...interface{}) uint32 {
 			}
 
 			s += size9p(elements...) + size9p(uint16(0))
+		case *Dir:
+			s += size9p(*v)
+		case Fcall:
+			s += size9p(v.Type, v.Tag, v.Message)
+		case *Fcall:
+			s += size9p(*v)
 		case Message:
-
 			// special case twstat and rstat for size fields. See bugs in
 			// http://man.cat-v.org/plan_9/5/stat to make sense of this.
 			switch v.(type) {
@@ -421,16 +431,6 @@ func size9p(vs ...interface{}) uint32 {
 			}
 
 			s += size9p(elements...)
-		case *Qid:
-			s += size9p(*v)
-		case *Dir:
-			s += size9p(*v)
-		case Fcall:
-			s += size9p(v.Type, v.Tag, v.Message)
-		case *Fcall:
-			s += size9p(*v)
-		default:
-			s += uint32(binary.Size(v))
 		}
 	}
 
blob - 1e3302a3d55f8a6c5b2a695cbcdaac3e5db3b1ad
blob + 4229e6c22b31eaa77298b5daf75173d338d778aa
--- encoding_test.go
+++ encoding_test.go
@@ -9,12 +9,23 @@ import (
 )
 
 func TestEncodeDecode(t *testing.T) {
+	codec := NewCodec()
 	for _, testcase := range []struct {
 		description string
 		target      interface{}
 		marshaled   []byte
 	}{
 		{
+			description: "uint8",
+			target:      uint8('U'),
+			marshaled:   []byte{0x55},
+		},
+		{
+			description: "uint16",
+			target:      uint16(0x5544),
+			marshaled:   []byte{0x44, 0x55},
+		},
+		{
 			description: "string",
 			target:      "asdf",
 			marshaled:   []byte{0x4, 0x0, 0x61, 0x73, 0x64, 0x66},
@@ -28,14 +39,25 @@ func TestEncodeDecode(t *testing.T) {
 				0x4, 0x0, 0x71, 0x77, 0x65, 0x72,
 				0x4, 0x0, 0x7a, 0x78, 0x63, 0x76},
 		},
+		{
+			description: "Qid",
+			target: Qid{
+				Type:    QTDIR,
+				Version: 0x10203040,
+				Path:    0x1020304050607080},
+			marshaled: []byte{
+				byte(QTDIR),            // qtype
+				0x40, 0x30, 0x20, 0x10, // version
+				0x80, 0x70, 0x60, 0x50, 0x40, 0x30, 0x20, 0x10, // path
+			},
+		},
 		// Dir
-		// Qid
 		{
 			description: "Tversion fcall",
 			target: &Fcall{
 				Type: Tversion,
 				Tag:  2255,
-				Message: &MessageTversion{
+				Message: MessageTversion{
 					MSize:   uint32(1024),
 					Version: "9PTEST",
 				},
@@ -49,7 +71,7 @@ func TestEncodeDecode(t *testing.T) {
 			target: &Fcall{
 				Type: Rversion,
 				Tag:  2255,
-				Message: &MessageRversion{
+				Message: MessageRversion{
 					MSize:   uint32(1024),
 					Version: "9PTEST",
 				},
@@ -63,7 +85,7 @@ func TestEncodeDecode(t *testing.T) {
 			target: &Fcall{
 				Type: Twalk,
 				Tag:  5666,
-				Message: &MessageTwalk{
+				Message: MessageTwalk{
 					Fid:    1010,
 					Newfid: 1011,
 					Wnames: []string{"a", "b", "c"},
@@ -81,7 +103,7 @@ func TestEncodeDecode(t *testing.T) {
 			target: &Fcall{
 				Type: Rwalk,
 				Tag:  5556,
-				Message: &MessageRwalk{
+				Message: MessageRwalk{
 					Qids: []Qid{
 						Qid{
 							Type:    QTDIR,
@@ -105,7 +127,7 @@ func TestEncodeDecode(t *testing.T) {
 			target: &Fcall{
 				Type: Rread,
 				Tag:  5556,
-				Message: &MessageRread{
+				Message: MessageRread{
 					Data: []byte("a lot of byte data"),
 				},
 			},
@@ -119,7 +141,7 @@ func TestEncodeDecode(t *testing.T) {
 			target: &Fcall{
 				Type: Rstat,
 				Tag:  5556,
-				Message: &MessageRstat{
+				Message: MessageRstat{
 					Stat: Dir{
 						Type: ^uint16(0),
 						Dev:  ^uint32(0),
@@ -166,24 +188,18 @@ func TestEncodeDecode(t *testing.T) {
 				0x41, 0x20, 0x73, 0x65, 0x72, 0x69, 0x6f, 0x75, 0x73, 0x20, 0x65, 0x72, 0x72, 0x6f, 0x72},
 		},
 	} {
-		t.Logf("target under test: %v", testcase.target)
+		t.Logf("target under test: %#v %T", testcase.target, testcase.target)
 		fatalf := func(format string, args ...interface{}) {
 			t.Fatalf(testcase.description+": "+format, args...)
 		}
 
-		t.Logf("expecting message of %v bytes", len(testcase.marshaled))
-
-		var b bytes.Buffer
-
-		enc := &encoder{&b}
-		dec := &decoder{&b}
-
-		if err := enc.encode(testcase.target); err != nil {
+		p, err := codec.Marshal(testcase.target)
+		if err != nil {
 			fatalf("error writing fcall: %v", err)
 		}
 
-		if !bytes.Equal(b.Bytes(), testcase.marshaled) {
-			fatalf("unexpected bytes for fcall: \n%#v != \n%#v", b.Bytes(), testcase.marshaled)
+		if !bytes.Equal(p, testcase.marshaled) {
+			fatalf("unexpected bytes for fcall: \n%#v != \n%#v", p, testcase.marshaled)
 		}
 
 		// check that size9p is working correctly
@@ -200,7 +216,7 @@ func TestEncodeDecode(t *testing.T) {
 			v = reflect.New(targetType).Interface()
 		}
 
-		if err := dec.decode(v); err != nil {
+		if err := codec.Unmarshal(p, v); err != nil {
 			fatalf("error reading: %v", err)
 		}
 
blob - 80d2e7e1fd009f0babc3639130e1519b72bdd533
blob + 5ce3d9ce659404d591a1c55302d4c95a0e675ef9
--- fcall.go
+++ fcall.go
@@ -108,14 +108,15 @@ type Fcall struct {
 
 func newFcall(msg Message) *Fcall {
 	var tag Tag
+	mtype := messageType(msg)
 
-	switch msg.Type() {
+	switch mtype {
 	case Tversion, Rversion:
 		tag = NOTAG
 	}
 
 	return &Fcall{
-		Type:    msg.Type(),
+		Type:    mtype,
 		Tag:     tag,
 		Message: msg,
 	}
@@ -128,7 +129,7 @@ func newErrorFcall(tag Tag, err error) *Fcall {
 	case MessageRerror:
 		msg = v
 	case *MessageRerror:
-		msg = v
+		msg = *v
 	default:
 		msg = MessageRerror{Ename: v.Error()}
 	}
@@ -142,225 +143,4 @@ func newErrorFcall(tag Tag, err error) *Fcall {
 
 func (fc *Fcall) String() string {
 	return fmt.Sprintf("%v(%v) %v", fc.Type, fc.Tag, string9p(fc.Message))
-}
-
-type Message interface {
-	// Type indicates the Fcall type of the message. This must match
-	// Fcall.Type.
-	Type() FcallType
-}
-
-// newMessage returns a new instance of the message based on the Fcall type.
-func newMessage(typ FcallType) (Message, error) {
-	// NOTE(stevvooe): This is a nasty bit of code but makes the transport
-	// fairly simple to implement.
-	switch typ {
-	case Tversion:
-		return &MessageTversion{}, nil
-	case Rversion:
-		return &MessageRversion{}, nil
-	case Tauth:
-		return &MessageTauth{}, nil
-	case Rauth:
-		return &MessageRauth{}, nil
-	case Tattach:
-		return &MessageTattach{}, nil
-	case Rattach:
-		return &MessageRattach{}, nil
-	case Rerror:
-		return &MessageRerror{}, nil
-	case Tflush:
-		return &MessageTflush{}, nil
-	case Rflush:
-		return &MessageRflush{}, nil // No message body for this response.
-	case Twalk:
-		return &MessageTwalk{}, nil
-	case Rwalk:
-		return &MessageRwalk{}, nil
-	case Topen:
-		return &MessageTopen{}, nil
-	case Ropen:
-		return &MessageRopen{}, nil
-	case Tcreate:
-		return &MessageTcreate{}, nil
-	case Rcreate:
-		return &MessageRcreate{}, nil
-	case Tread:
-		return &MessageTread{}, nil
-	case Rread:
-		return &MessageRread{}, nil
-	case Twrite:
-		return &MessageTwrite{}, nil
-	case Rwrite:
-		return &MessageRwrite{}, nil
-	case Tclunk:
-		return &MessageTclunk{}, nil
-	case Rclunk:
-		return &MessageRclunk{}, nil // no response body
-	case Tremove:
-
-	case Rremove:
-
-	case Tstat:
-
-	case Rstat:
-		return &MessageRstat{}, nil
-	case Twstat:
-
-	case Rwstat:
-
-	}
-
-	return nil, fmt.Errorf("unknown message type")
-}
-
-// MessageVersion encodes the message body for Tversion and Rversion RPC
-// calls. The body is identical in both directions.
-type MessageTversion struct {
-	MSize   uint32
-	Version string
-}
-
-type MessageRversion struct {
-	MSize   uint32
-	Version string
-}
-
-type MessageTauth struct {
-	Afid  Fid
-	Uname string
-	Aname string
-}
-
-type MessageRauth struct {
-	Qid Qid
-}
-
-type MessageRerror struct {
-	Ename string
-}
-
-func (e MessageRerror) Error() string {
-	return fmt.Sprintf("9p: %v", e.Ename)
-}
-
-type MessageTflush struct {
-	Oldtag Tag
-}
-
-type MessageRflush struct{}
-
-type MessageTattach struct {
-	Fid   Fid
-	Afid  Fid
-	Uname string
-	Aname string
 }
-
-type MessageRattach struct {
-	Qid Qid
-}
-
-type MessageTwalk struct {
-	Fid    Fid
-	Newfid Fid
-	Wnames []string
-}
-
-type MessageRwalk struct {
-	Qids []Qid
-}
-
-type MessageTopen struct {
-	Fid  Fid
-	Mode uint8
-}
-
-type MessageRopen struct {
-	Qid    Qid
-	IOUnit uint32
-}
-
-type MessageTcreate struct {
-	Fid  Fid
-	Name string
-	Perm uint32
-	Mode uint8
-}
-
-type MessageRcreate struct {
-	Qid    Qid
-	IOUnit uint32
-}
-
-type MessageTread struct {
-	Fid    Fid
-	Offset uint64
-	Count  uint32
-}
-
-type MessageRread struct {
-	Data []byte
-}
-
-type MessageTwrite struct {
-	Fid    Fid
-	Offset uint64
-	Data   []byte
-}
-
-type MessageRwrite struct {
-	Count uint32
-}
-
-type MessageTclunk struct {
-	Fid Fid
-}
-
-type MessageRclunk struct{}
-
-type MessageTremove struct {
-	Fid Fid
-}
-
-type MessageRremove struct{}
-
-type MessageTstat struct {
-	Fid Fid
-}
-
-type MessageRstat struct {
-	Stat Dir
-}
-
-type MessageTwstat struct {
-	Fid  Fid
-	Stat Dir
-}
-
-func (MessageTversion) Type() FcallType { return Tversion }
-func (MessageRversion) Type() FcallType { return Rversion }
-func (MessageTauth) Type() FcallType    { return Tauth }
-func (MessageRauth) Type() FcallType    { return Rauth }
-func (MessageTflush) Type() FcallType   { return Tflush }
-func (MessageRflush) Type() FcallType   { return Rflush }
-func (MessageRerror) Type() FcallType   { return Rerror }
-func (MessageTattach) Type() FcallType  { return Tattach }
-func (MessageRattach) Type() FcallType  { return Rattach }
-func (MessageTwalk) Type() FcallType    { return Twalk }
-func (MessageRwalk) Type() FcallType    { return Rwalk }
-func (MessageTopen) Type() FcallType    { return Topen }
-func (MessageRopen) Type() FcallType    { return Ropen }
-func (MessageTcreate) Type() FcallType  { return Tcreate }
-func (MessageRcreate) Type() FcallType  { return Rcreate }
-func (MessageTread) Type() FcallType    { return Tread }
-func (MessageRread) Type() FcallType    { return Rread }
-func (MessageTwrite) Type() FcallType   { return Twrite }
-func (MessageRwrite) Type() FcallType   { return Rwrite }
-func (MessageTclunk) Type() FcallType   { return Tclunk }
-func (MessageRclunk) Type() FcallType   { return Rclunk }
-func (MessageTremove) Type() FcallType  { return Tremove }
-func (MessageRremove) Type() FcallType  { return Rremove }
-func (MessageTstat) Type() FcallType    { return Tstat }
-func (MessageRstat) Type() FcallType    { return Rstat }
-func (MessageTwstat) Type() FcallType   { return Twstat }
blob - /dev/null
blob + a3c74c87f6bcdac3e0cbeefd612a8da2a49002b6 (mode 644)
--- /dev/null
+++ messages.go
@@ -0,0 +1,258 @@
+package p9pnew
+
+import "fmt"
+
+// Message represents the target of an fcall: the type-specific body stored
+// in Fcall.Message.
+//
+// Message is the empty interface, so any value can be assigned to it and the
+// compiler cannot reject unsupported bodies; messageType panics on types it
+// does not recognize. Note also that a `case Message:` in a type switch
+// matches every value, which is why the codec's type switches list it last,
+// where it effectively acts as the default case.
+type Message interface{}
+
+// newMessage returns a new, zero-valued message body for the given Fcall
+// type. Bodies are returned as values rather than pointers; a caller that
+// wants to decode into the body must allocate a pointer to its concrete type
+// itself (the decoder's *Fcall case does this with reflect.New).
+//
+// An error naming the offending type is returned when typ is not recognized,
+// so unexpected or malformed traffic can be diagnosed from the error alone.
+func newMessage(typ FcallType) (Message, error) {
+	switch typ {
+	case Tversion:
+		return MessageTversion{}, nil
+	case Rversion:
+		return MessageRversion{}, nil
+	case Tauth:
+		return MessageTauth{}, nil
+	case Rauth:
+		return MessageRauth{}, nil
+	case Tattach:
+		return MessageTattach{}, nil
+	case Rattach:
+		return MessageRattach{}, nil
+	case Rerror:
+		return MessageRerror{}, nil
+	case Tflush:
+		return MessageTflush{}, nil
+	case Rflush:
+		return MessageRflush{}, nil // No message body for this response.
+	case Twalk:
+		return MessageTwalk{}, nil
+	case Rwalk:
+		return MessageRwalk{}, nil
+	case Topen:
+		return MessageTopen{}, nil
+	case Ropen:
+		return MessageRopen{}, nil
+	case Tcreate:
+		return MessageTcreate{}, nil
+	case Rcreate:
+		return MessageRcreate{}, nil
+	case Tread:
+		return MessageTread{}, nil
+	case Rread:
+		return MessageRread{}, nil
+	case Twrite:
+		return MessageTwrite{}, nil
+	case Rwrite:
+		return MessageRwrite{}, nil
+	case Tclunk:
+		return MessageTclunk{}, nil
+	case Rclunk:
+		return MessageRclunk{}, nil // no response body
+	case Tremove:
+		return MessageTremove{}, nil
+	case Rremove:
+		return MessageRremove{}, nil
+	case Tstat:
+		return MessageTstat{}, nil
+	case Rstat:
+		return MessageRstat{}, nil
+	case Twstat:
+		return MessageTwstat{}, nil
+	case Rwstat:
+		return MessageRwstat{}, nil
+	}
+
+	// Include the offending type: the decoder no longer logs it, so this
+	// error is the only place the value is reported.
+	return nil, fmt.Errorf("unknown message type: %v", typ)
+}
+
+// messageType returns the Fcall type that corresponds to the concrete message
+// body m. Only the value (non-pointer) message types are listed; a pointer
+// such as *MessageTversion is not recognized and causes a panic, except
+// *MessageRerror, which satisfies error and so maps to Rerror.
+//
+// It panics on unrecognized types, which is treated as a programming error.
+func messageType(m Message) FcallType {
+	switch v := m.(type) {
+	case MessageTversion:
+		return Tversion
+	case MessageRversion:
+		return Rversion
+	case MessageTauth:
+		return Tauth
+	case MessageRauth:
+		return Rauth
+	case MessageTflush:
+		return Tflush
+	case MessageRflush:
+		return Rflush
+	case MessageRerror:
+		return Rerror
+	case MessageTattach:
+		return Tattach
+	case MessageRattach:
+		return Rattach
+	case MessageTwalk:
+		return Twalk
+	case MessageRwalk:
+		return Rwalk
+	case MessageTopen:
+		return Topen
+	case MessageRopen:
+		return Ropen
+	case MessageTcreate:
+		return Tcreate
+	case MessageRcreate:
+		return Rcreate
+	case MessageTread:
+		return Tread
+	case MessageRread:
+		return Rread
+	case MessageTwrite:
+		return Twrite
+	case MessageRwrite:
+		return Rwrite
+	case MessageTclunk:
+		return Tclunk
+	case MessageRclunk:
+		return Rclunk
+	case MessageTremove:
+		return Tremove
+	case MessageRremove:
+		return Rremove
+	case MessageTstat:
+		return Tstat
+	case MessageRstat:
+		return Rstat
+	case MessageTwstat:
+		return Twstat
+	case MessageRwstat:
+		return Rwstat
+	case error:
+		// Any other error value (MessageRerror itself is matched above) is
+		// reported as Rerror.
+		// NOTE(review): the value is not converted to MessageRerror here, so
+		// an Fcall built from it via newFcall may not encode as a well-formed
+		// Rerror; newErrorFcall performs that conversion — confirm callers
+		// route errors through it.
+		return Rerror
+	default:
+		// NOTE(stevvooe): This is considered a programming error.
+		panic(fmt.Sprintf("unsupported message type: %T", v))
+	}
+}
+
+// NOTE: the codec encodes message bodies field by field in struct field
+// order, so the fields of the types below must not be reordered.
+
+// MessageTversion is the body of a Tversion message. MessageRversion has an
+// identical layout and is used for the reply.
+type MessageTversion struct {
+	MSize   uint32
+	Version string
+}
+
+// MessageRversion is the body of an Rversion message.
+type MessageRversion struct {
+	MSize   uint32
+	Version string
+}
+
+// MessageTauth is the body of a Tauth message.
+type MessageTauth struct {
+	Afid  Fid
+	Uname string
+	Aname string
+}
+
+// MessageRauth is the body of an Rauth message.
+type MessageRauth struct {
+	Qid Qid
+}
+
+// MessageRerror is the body of an Rerror message. It implements error.
+type MessageRerror struct {
+	Ename string
+}
+
+// Error returns Ename prefixed with "9p: ".
+func (e MessageRerror) Error() string {
+	return fmt.Sprintf("9p: %v", e.Ename)
+}
+
+// MessageTflush is the body of a Tflush message.
+type MessageTflush struct {
+	Oldtag Tag
+}
+
+// MessageRflush is the body of an Rflush message; it has no fields.
+type MessageRflush struct{}
+
+// MessageTattach is the body of a Tattach message.
+type MessageTattach struct {
+	Fid   Fid
+	Afid  Fid
+	Uname string
+	Aname string
+}
+
+// MessageRattach is the body of an Rattach message.
+type MessageRattach struct {
+	Qid Qid
+}
+
+// MessageTwalk is the body of a Twalk message.
+type MessageTwalk struct {
+	Fid    Fid
+	Newfid Fid
+	Wnames []string
+}
+
+// MessageRwalk is the body of an Rwalk message.
+type MessageRwalk struct {
+	Qids []Qid
+}
+
+// MessageTopen is the body of a Topen message.
+type MessageTopen struct {
+	Fid  Fid
+	Mode uint8
+}
+
+// MessageRopen is the body of an Ropen message.
+type MessageRopen struct {
+	Qid    Qid
+	IOUnit uint32
+}
+
+// MessageTcreate is the body of a Tcreate message.
+type MessageTcreate struct {
+	Fid  Fid
+	Name string
+	Perm uint32
+	Mode uint8
+}
+
+// MessageRcreate is the body of an Rcreate message.
+type MessageRcreate struct {
+	Qid    Qid
+	IOUnit uint32
+}
+
+// MessageTread is the body of a Tread message.
+type MessageTread struct {
+	Fid    Fid
+	Offset uint64
+	Count  uint32
+}
+
+// MessageRread is the body of an Rread message. The codec writes Data with
+// a 4-byte length prefix.
+type MessageRread struct {
+	Data []byte
+}
+
+// MessageTwrite is the body of a Twrite message. The codec writes Data with
+// a 4-byte length prefix.
+type MessageTwrite struct {
+	Fid    Fid
+	Offset uint64
+	Data   []byte
+}
+
+// MessageRwrite is the body of an Rwrite message.
+type MessageRwrite struct {
+	Count uint32
+}
+
+// MessageTclunk is the body of a Tclunk message.
+type MessageTclunk struct {
+	Fid Fid
+}
+
+// MessageRclunk is the body of an Rclunk message; it has no fields.
+type MessageRclunk struct{}
+
+// MessageTremove is the body of a Tremove message.
+type MessageTremove struct {
+	Fid Fid
+}
+
+// MessageRremove is the body of an Rremove message; it has no fields.
+type MessageRremove struct{}
+
+// MessageTstat is the body of a Tstat message.
+type MessageTstat struct {
+	Fid Fid
+}
+
+// MessageRstat is the body of an Rstat message. The codec special-cases this
+// type: on the wire Stat is preceded by an extra size field in addition to
+// the size that Dir itself encodes.
+type MessageRstat struct {
+	Stat Dir
+}
+
+// MessageTwstat is the body of a Twstat message.
+//
+// NOTE(review): per stat(5), Twstat (like Rstat) carries the stat size
+// twice, but the encoder's Message case only prepends the extra size for
+// Rstat — confirm whether Twstat needs the same special case.
+type MessageTwstat struct {
+	Fid  Fid
+	Stat Dir
+}
+
+// MessageRwstat is the body of an Rwstat message; it has no fields.
+type MessageRwstat struct{}