commit 499f8c59005e11c0b5590adabe8b660c4a4cf1cb from: Stephen J Day date: Tue Oct 27 04:16:51 2015 UTC fs/p9p/new: Lock down encode/decode for 9p Among other things, this adds support for 9p encoding and decoding. We now have fully reflective message marshaling and unmarshaling. Other aspects of this commit include the code actually compiling and movement towards a testable interface. Signed-off-by: Stephen J Day commit - 8a7ec69711074b12b52a63f9eb61ce8cc82425bb commit + 499f8c59005e11c0b5590adabe8b660c4a4cf1cb blob - 1e7044f40483d25674fb40acc8b2a71b900925c7 blob + 5d526a8dcd50d0bf26806e5a1c3e3521a209b8b2 --- client.go +++ client.go @@ -14,14 +14,15 @@ import ( type client struct { conn net.Conn tags *tagPool - requests chan *fcallRequest + requests chan fcallRequest + closed chan struct{} } // NewSession returns a session using the connection. func NewSession(conn net.Conn) (Session, error) { return &client{ conn: conn, - } + }, nil } var _ Session = &client{} @@ -70,13 +71,12 @@ func (c *client) WStat(context.Context, Fid, Dir) erro panic("not implemented") } -func (c *client) Version(ctx context.Context, msize int32, version string) (int32, string, error) { +func (c *client) Version(ctx context.Context, msize uint32, version string) (uint32, string, error) { fcall := &Fcall{ - Type: TVersion, - Tag: tag, + Type: Tversion, Message: MessageVersion{ - MSize: msize, - Version: Version, + MSize: uint32(msize), + Version: version, }, } @@ -87,12 +87,18 @@ func (c *client) Version(ctx context.Context, msize in mv, ok := resp.Message.(*MessageVersion) if !ok { - return fmt.Errorf("invalid rpc response for version message: %v", resp) + return 0, "", fmt.Errorf("invalid rpc response for version message: %v", resp) } return mv.MSize, mv.Version, nil } +func (c *client) flush(ctx context.Context, tag Tag) error { + // TODO(stevvooe): We need to fire and forget flush messages when a call + // context gets cancelled. + panic("not implemented") +} + // send dispatches the fcall. func (c *client) send(ctx context.Context, fc *Fcall) (*Fcall, error) { fc.Tag = c.tags.Get() @@ -111,7 +117,7 @@ func (c *client) send(ctx context.Context, fc *Fcall) // wait for the response. select { - case <-closed: + case <-c.closed: return nil, ErrClosed case <-ctx.Done(): return nil, ctx.Err() @@ -132,7 +138,7 @@ func newFcallRequest(ctx context.Context, fc *Fcall) f ctx: ctx, fcall: fc, response: make(chan *Fcall, 1), - err: make(chan err, 1), + err: make(chan error, 1), } } @@ -147,7 +153,7 @@ func (c *client) handle() { // loop to read messages off of the connection go func() { - r := bufio.NewReader(c.conn) + dec := &decoder{bufio.NewReader(c.conn)} loop: for { @@ -160,7 +166,7 @@ func (c *client) handle() { } fc := new(Fcall) - if err := read9p(r, fc); err != nil { + if err := dec.decode(fc); err != nil { switch err := err.(type) { case net.Error: if err.Timeout() || err.Temporary() { @@ -172,7 +178,7 @@ func (c *client) handle() { } select { - case <-closed: + case <-c.closed: return case responses <- fc: } @@ -180,14 +186,14 @@ func (c *client) handle() { }() - w := bufio.NewWriter(c.conn) + enc := &encoder{bufio.NewWriter(c.conn)} for { select { case <-c.closed: return case req := <-c.requests: - outstanding[req.fcall.Tag] = req + outstanding[req.fcall.Tag] = &req // use deadline to set write deadline for this request. deadline, ok := req.ctx.Deadline() @@ -199,7 +205,7 @@ func (c *client) handle() { log.Println("error setting write deadline: %v", err) } - if err := write9p(w, req.fcall); err != nil { + if err := enc.encode(req.fcall); err != nil { delete(outstanding, req.fcall.Tag) req.err <- err } blob - /dev/null blob + 081bacb99c65174f24ce8108cfa6eebc7219e27e (mode 644) --- /dev/null +++ encoding.go @@ -0,0 +1,204 @@ +package p9pnew + +import ( + "encoding/binary" + "fmt" + "io" + "reflect" +) + +// NOTE(stevvooe): This file covers 9p encoding and decoding (despite just +// being called encoding). + +type encoder struct { + wr io.Writer +} + +func (e *encoder) encode(vs ...interface{}) error { + for _, v := range vs { + switch v := v.(type) { + case *string: + if err := e.encode(*v); err != nil { + return err + } + case string: + // implement string[s] encoding + if err := binary.Write(e.wr, binary.LittleEndian, uint16(len(v))); err != nil { + return err + } + + _, err := io.WriteString(e.wr, v) + if err != nil { + return err + } + case Message: + // walk the fields of the message to get the total size. we just + // use the field order from the message struct. We may add tag + // ignoring if needed. + elements, err := fields9p(v) + if err != nil { + return err + } + + if err := e.encode(elements...); err != nil { + return err + } + case Fcall: + if err := e.encode(&v); err != nil { + return err + } + case *Fcall: + if err := e.encode(size9p(v), v.Type, v.Tag, v.Message); err != nil { + return err + } + default: + if err := binary.Write(e.wr, binary.LittleEndian, v); err != nil { + return err + } + } + } + + return nil +} + +type decoder struct { + rd io.Reader +} + +// read9p extracts values from rd and unmarshals them to the targets of vs. +func (d *decoder) decode(vs ...interface{}) error { + for _, v := range vs { + switch v := v.(type) { + case *string: + var ll uint16 + + // implement string[s] encoding + if err := binary.Read(d.rd, binary.LittleEndian, &ll); err != nil { + return err + } + + b := make([]byte, ll) + + n, err := io.ReadFull(d.rd, b) + if err != nil { + return err + } + + if n != int(ll) { + return fmt.Errorf("unexpected string length") + } + + *v = string(b) + case *Fcall: + var size uint32 + if err := d.decode(&size, &v.Type, &v.Tag); err != nil { + return err + } + + var err error + v.Message, err = newMessage(v.Type) + if err != nil { + return err + } + + if err := d.decode(v.Message); err != nil { + return err + } + case Message: + elements, err := fields9p(v) + if err != nil { + return err + } + + if err := d.decode(elements...); err != nil { + return err + } + default: + if err := binary.Read(d.rd, binary.LittleEndian, v); err != nil { + return err + } + } + } + + return nil +} + +// size9p calculates the projected size of the values in vs when encoded into +// 9p binary protocol. If an element or elements are not valid for 9p encoded, +// the value 0 will be used for the size. The error will be detected when +// encoding. +func size9p(vs ...interface{}) uint32 { + var s uint32 + for _, v := range vs { + if v == nil { + continue + } + + switch v := v.(type) { + case *string: + s += uint32(binary.Size(uint16(0)) + len(*v)) + case string: + s += uint32(binary.Size(uint16(0)) + len(v)) + case Message: + // walk the fields of the message to get the total size. we just + // use the field order from the message struct. We may add tag + // ignoring if needed. + elements, err := fields9p(v) + if err != nil { + // BUG(stevvooe): The options here are to return 0, panic or + // make this return an error. Ideally, we make it safe to + // return 0 and have the rest of the package do the right + // thing. For now, we do this, but may want to panic until + // things are stable. + return 0 + } + + s += size9p(elements...) + case Fcall: + s += size9p(v.Type, v.Tag, v.Message) + case *Fcall: + // Calculates the total size of the fcall, excluding the size + // header, which is handled exernally. The result of + // (*Fcall).MarshalBinary will have len(p) == the result of this + // branch. The value should be + s += size9p(v.Type, v.Tag, v.Message) + default: + s += uint32(binary.Size(v)) + } + } + + return s +} + +// fields9p lists the settable fields from a struct type for reading and +// writing. We are using a lot of reflection here for fairly static +// serialization but we can replace this in the future with generated code if +// performance is an issue. +func fields9p(v interface{}) ([]interface{}, error) { + rv := reflect.Indirect(reflect.ValueOf(v)) + + if rv.Kind() != reflect.Struct { + return nil, fmt.Errorf("cannot extract fields from non-struct: %v", rv) + } + + var elements []interface{} + for i := 0; i < rv.NumField(); i++ { + f := rv.Field(i) + + if !f.CanInterface() { + return nil, fmt.Errorf("can't interface: %v", f) + } + + if !f.CanSet() { + return nil, fmt.Errorf("cannot set %v", f) + } + + if f.CanAddr() { + f = f.Addr() + } + + elements = append(elements, f.Interface()) + } + + return elements, nil +} blob - 73b68dd64117abec8a2328e35d03d23921bd9f42 blob + d550f4532b8a7fa29594fd215f549c129367981e --- fcall.go +++ fcall.go @@ -1,14 +1,7 @@ package p9pnew -import ( - "bytes" - "encoding/binary" - "fmt" - "io" +import "fmt" - "encoding" -) - type FcallType uint8 const ( @@ -43,46 +36,99 @@ const ( Tmax ) +func (fct FcallType) String() string { + switch fct { + case Tversion: + return "Tversion" + case Rversion: + return "Rversion" + case Tauth: + return "Tauth" + case Rauth: + return "Rauth" + case Tattach: + return "Tattach" + case Rattach: + return "Rattach" + case Terror: + // invalid. + return "Terror" + case Rerror: + return "Rerror" + case Tflush: + return "Tflush" + case Rflush: + return "Rflush" + case Twalk: + return "Twalk" + case Rwalk: + return "Rwalk" + case Topen: + return "Topen" + case Ropen: + return "Ropen" + case Tcreate: + return "Tcreate" + case Rcreate: + return "Rcreate" + case Tread: + return "Tread" + case Rread: + return "Rread" + case Twrite: + return "Twrite" + case Rwrite: + return "Rwrite" + case Tclunk: + return "Tclunk" + case Rclunk: + return "Rclunk" + case Tremove: + return "Tremote" + case Rremove: + return "Rremove" + case Tstat: + return "Tstat" + case Rstat: + return "Rstat" + case Twstat: + return "Twstat" + case Rwstat: + return "Rwstat" + default: + return "Tunknown" + } +} + type Fcall struct { - Type Type + Type FcallType Tag Tag Message Message } -const ( - fcallHeaderSize = 4 /*size*/ + 1 /*type*/ -) - -func (fc *Fcall) Size() int { - return fcallHeaderSize + fc.Message.Size() +func (fc Fcall) String() string { + return fmt.Sprintf("%8d %v(%v) %v", size9p(fc), fc.Type, fc.Tag, fc.Message) } -func (fc *Fcall) MarshalBinary() ([]byte, error) { - mp, err := fc.Message.MarshalBinary() - if err != nil { - return nil, err - } +type Message interface { + // Size() uint32 - b := bytes.NewBuffer(make([]byte, 0, fc.Size())) - if err := write9p(b, fc.Size(), fc.Tag, mp); err != nil { - return nil, err - } + // NOTE(stevvooe): The binary marshal approach isn't particularly nice to + // generating garbage. Consider using an append model, once we have the + // messages worked out. + // encoding.BinaryMarshaler + // encoding.BinaryUnmarshaler - return b.Bytes(), nil + message9p() } -func (fc *Fcall) UnmarshalBinary(p []data) error { - var ( - r = bytes.NewReader(p) - ) - - if err := read9p(r, &fc.Type, &fc.Tag); err != nil { - return err - } - - switch fc.Type { +// newMessage returns a new instance of the message based on the Fcall type. +func newMessage(typ FcallType) (Message, error) { + // NOTE(stevvooe): This is a nasty bit of code but makes the transport + // fairly simple to implement. + switch typ { case Tversion, Rversion: - fc.Message = &MessageVersion{} + return &MessageVersion{}, nil case Tauth: case Rauth: @@ -96,9 +142,9 @@ func (fc *Fcall) UnmarshalBinary(p []data) error { case Rerror: case Tflush: - + return &MessageFlush{}, nil case Rflush: - + return nil, nil // No message body for this response. case Twalk: case Rwalk: @@ -134,18 +180,14 @@ func (fc *Fcall) UnmarshalBinary(p []data) error { case Twstat: case Rwstat: + default: + return nil, fmt.Errorf("unknown message type: %v", typ) } - return fc.Message.UnmarshalBinary(p[len(p)-r.Len():]) + return nil, fmt.Errorf("unknown message") } -type Message interface { - Size() int - encoding.BinaryMarshaler - encoding.BinaryUnmarshaler -} - // MessageVersion encodes the message body for Tversion and Rversion RPC // calls. The body is identical in both directions. type MessageVersion struct { @@ -153,109 +195,14 @@ type MessageVersion struct { Version string } -func (mv MessageVersion) Size() int { - return 4 + 2 + len(mv.Version) +func (MessageVersion) message9p() {} +func (mv MessageVersion) String() string { + return fmt.Sprintf("msize=%v version=%v", mv.MSize, mv.Version) } -func (mv MessageVersion) MarshalBinary() ([]byte, error) { - b := bytes.NewBuffer(make([]byte, 0, mv.Size())) - - if err := write9p(b, mv.MSize, mv.Version); err != nil { - return nil, err - } - - return b.Bytes(), nil +// MessageFlush handles the content for the Tflush message type. +type MessageFlush struct { + Oldtag Tag } -// write9p implements serialization for base types. -func write9p(w io.Writer, vs ...interface{}) error { - for _, v := range vs { - switch v := v.(type) { - case string: - // implement string[s] encoding - if err := binary.Write(w, binary.LittleEndian, uint16(len(v))); err != nil { - return err - } - - _, err := io.WriteString(w, s) - if err != nil { - - return err - } - case *Fcall: - if err := write9p(w, v.Size()); err != nil { - return err - } - p, err := v.MarshalBinary() - if err != nil { - return err - } - - n, err := w.Write(p) - if err != nil { - return err - } - - if n != len(p) { - return io.ErrShortWrite - } - - return nil - default: - if err := binary.Write(w, binary.LittleEndian, v); err != nil { - return err - } - } - } - - return nil -} - -// read9p extracts values from rd and unmarshals them to the targets of vs. -func read9p(rd io.Reader, vs ...interface{}) error { - for _, v := range vs { - switch v := v.(type) { - case *string: - var ll uint16 - - // implement string[s] encoding - if err := binary.Read(r, binary.LittleEndian, &ll); err != nil { - return err - } - - b := make([]byte, ll) - - n, err := io.ReadFull(b) - if err != nil { - return err - } - - if n != int(ll) { - return fmt.Errorf("unexpected string length") - } - - *v = string(b) - case *Fcall: - var size uint32 - if err := read9p(buffered, &size); err != nil { - return err - } - - p := make([]byte, size) - n, err := io.ReadFull(p) - if err != nil { - return err - } - - if n != size { - return fmt.Errorf("error reading fcall: short read") - } - - return v.UnmarshalBinary(p) - default: - if err := binary.Read(r, binary.LittleEndian, v); err != nil { - return err - } - } - } -} +func (MessageFlush) message9p() {} blob - /dev/null blob + 074f520e9f040f5f136f3dec00b7db9a713be4f1 (mode 644) --- /dev/null +++ encoding_test.go @@ -0,0 +1,72 @@ +package p9pnew + +import ( + "bytes" + "reflect" + "testing" +) + +func TestEncodeDecode(t *testing.T) { + for _, testcase := range []struct { + description string + target interface{} + marshaled []byte + }{ + { + description: "string", + target: "asdf", + marshaled: []byte{0x4, 0x0, 0x61, 0x73, 0x64, 0x66}, + }, + { + description: "Tversion fcall", + target: &Fcall{ + Type: Tversion, + Tag: 2255, + Message: &MessageVersion{ + MSize: uint32(1024), + Version: "9PTEST", + }, + }, + marshaled: []byte{0xf, 0x0, 0x0, 0x0, 0x64, 0xcf, 0x8, 0x0, 0x4, 0x0, 0x0, 0x6, 0x0, 0x39, 0x50, 0x54, 0x45, 0x53, 0x54}, + }, + } { + t.Logf("target under test: %v", testcase.target) + fatalf := func(format string, args ...interface{}) { + t.Fatalf(testcase.description+": "+format, args...) + } + + var b bytes.Buffer + + enc := &encoder{&b} + dec := &decoder{&b} + + if err := enc.encode(testcase.target); err != nil { + fatalf("error writing fcall: %v", err) + } + + if !bytes.Equal(b.Bytes(), testcase.marshaled) { + fatalf("unexpected bytes for fcall: %#v != %#v", b.Bytes(), testcase.marshaled) + } + + var v interface{} + targetType := reflect.TypeOf(testcase.target) + + if targetType.Kind() == reflect.Ptr { + v = reflect.New(targetType.Elem()).Interface() + } else { + v = reflect.New(targetType).Interface() + } + + if err := dec.decode(v); err != nil { + fatalf("error reading fcall: %v", err) + } + + if targetType.Kind() != reflect.Ptr { + v = reflect.Indirect(reflect.ValueOf(v)).Interface() + } + + if !reflect.DeepEqual(v, testcase.target) { + fatalf("not equal: %#v != %#v", v, testcase.target) + } + } +} blob - df87804033610f46a568197526e072fb5292a7c9 blob + f1bae724b08d7812eafd714ad375fcaa413085ed --- logging.go +++ logging.go @@ -1,3 +1,5 @@ +// +build ignore + package p9pnew import ( blob - 287a302388f8565d8ce9c6fcb628fa154a9d1a7a blob + 032f07f44d7148b2ac5d6755d3afae997b5606f5 --- session.go +++ session.go @@ -33,7 +33,7 @@ type Session interface { // TODO(stevvooe): The version message affects a lot of protocol behavior. // Consider hiding it behind the implementation, letting the version get // negotiated. The API user should still be able to query it. - Version(ctx context.Context, msize int32, version string) (int32, string, error) + Version(ctx context.Context, msize uint32, version string) (uint32, string, error) } func Dial(addr string) (Session, error) { blob - 23ab0c4ea3426ea68d3967bfc611167ad034ef77 blob + 524e7b4ec21e4fcfb010d59e7ba62e07fb3a78f1 --- types.go +++ types.go @@ -1,11 +1,7 @@ package p9pnew -import ( - "encoding" +import "time" - "time" -) - const ( NOFID = ^Fid(0) NOTAG = ^Tag(0) @@ -75,60 +71,3 @@ type Dir struct { // type Tag uint16 - -type FcallType uint8 - -const ( - FcallTypeTversion FcallType = iota + 100 - FcallTypeRversion - FcallTypeTauth - FcallTypeRauth - FcallTypeTattach - FcallTypeRattach - FcallTypeTerror - FcallTypeRerror - FcallTypeTflush - FcallTypeRflush - FcallTypeTwalk - FcallTypeRwalk - FcallTypeTopen - FcallTypeRopen - FcallTypeTcreate - FcallTypeRcreate - FcallTypeTread - FcallTypeRread - FcallTypeTwrite - FcallTypeRwrite - FcallTypeTclunk - FcallTypeRclunk - FcallTypeTremove - FcallTypeRremove - FcallTypeTstat - FcallTypeRstat - FcallTypeTwstat - FcallTypeRwstat - FcallTypeTmax -) - -type Fcall struct { - Type Type - Fid Fid - Tag Tag - Message Message -} - -type Message interface { - Size() int - encoding.BinaryMarshaler - encoding.BinaryUnmarshaler -} - -type MessageVersion struct { - MSize uint32 - Version string -} - -func (mv MessageVersion) MarshalBinary() ([]byte, error) { - - encoding.BinaryMarshaler -} blob - /dev/null blob + 1074f83dc94bc6e9447cf920226f37465cd3d61c (mode 644) --- /dev/null +++ server.go @@ -0,0 +1,8 @@ +package p9pnew + +import "net" + +// Serve the 9p session over the provided network connection. +func Serve(conn net.Conn, session Session) error { + panic("not implemented") +}