Commit Diff


commit - c23dcc663854db7ecbfb882d85d7e2356b4fd355
commit + 8b0a506230428aaacefd78995955e80f058d9c9d
blob - e7608b4dfe2532e3d7464470c55310f0d9230d27
blob + f1f81fb14b30948b989307846bb4992080644f53
--- kamid/client.c
+++ kamid/client.c
@@ -43,15 +43,6 @@
 #include "utils.h"
 
 /*
- * XXX: atm is difficult to accept messages bigger than MAX_IMSGSIZE
- * minus IMSG_HEADER_SIZE, we need something to split messages into
- * chunks and receive them one by the other.
- *
- * CLIENT_MSIZE is thus the maximum message size we can handle now.
- */
-#define CLIENT_MSIZE (MAX_IMSGSIZE - IMSG_HEADER_SIZE)
-
-/*
  * The minimum value allowed for the msize.
  */
 #define MIN_MSIZE 256
@@ -186,11 +177,12 @@ static void	twalk(struct np_msg_header *, const uint8_
 static void	topen(struct np_msg_header *, const uint8_t *, size_t);
 static void	tcreate(struct np_msg_header *, const uint8_t *, size_t);
 static void	tread(struct np_msg_header *, const uint8_t *, size_t);
-static void	twrite(struct np_msg_header *, const uint8_t *, size_t);
+static void	twrite(struct np_msg_header *, const uint8_t *, size_t, struct fid **, off_t *, size_t *, int *);
+static void	twrite_cont(struct fid *, off_t *, size_t *, int *, uint16_t, const uint8_t *, size_t);
 static void	tstat(struct np_msg_header *, const uint8_t *, size_t);
 static void	twstat(struct np_msg_header *, const uint8_t *, size_t);
 static void	tremove(struct np_msg_header *, const uint8_t *, size_t);
-static void	handle_message(struct imsg *, size_t);
+static void	handle_message(struct imsg *, size_t, int);
 
 __dead void
 client(int debug, int verbose)
@@ -337,10 +329,12 @@ client_dispatch_listener(int fd, short event, void *d)
 			explicit_bzero(&rauth, sizeof(rauth));
 			break;
 		case IMSG_BUF:
+		case IMSG_BUF_CONT:
 			if (!auth)
 				fatalx("%s: can't handle messages before"
 				    " doing the auth", __func__);
-			handle_message(&imsg, IMSG_DATA_SIZE(imsg));
+			handle_message(&imsg, IMSG_DATA_SIZE(imsg),
+			    imsg.hdr.type == IMSG_BUF_CONT);
 			break;
 		case IMSG_CONN_GONE:
 			log_debug("closing");
@@ -962,7 +956,7 @@ tversion(struct np_msg_header *hdr, const uint8_t *dat
 
 	/* version matched */
 	handshaked = 1;
-	msize = MIN(msize, CLIENT_MSIZE);
+	msize = MIN(msize, MSIZE9P);
 	client_send_listener(IMSG_MSIZE, &msize, sizeof(msize));
 	np_version(hdr->tag, msize, VERSION9P);
 	return;
@@ -1555,7 +1549,8 @@ tread(struct np_msg_header *hdr, const uint8_t *data, 
 }
 
 static void
-twrite(struct np_msg_header *hdr, const uint8_t *data, size_t len)
+twrite(struct np_msg_header *hdr, const uint8_t *data, size_t len,
+    struct fid **writefid, off_t *writepos, size_t *writeleft, int *writeskip)
 {
 	struct fid	*f;
 	ssize_t		 r;
@@ -1566,36 +1561,98 @@ twrite(struct np_msg_header *hdr, const uint8_t *data,
 	if (!NPREAD32("fid", &fid, &data, &len) ||
 	    !NPREAD64("off", &off, &data, &len) ||
 	    !NPREAD32("count", &count, &data, &len) ||
-	    len != count) {
+	    count < len) {
 		client_send_listener(IMSG_CLOSE, NULL, 0);
 		client_shutdown();
 		return;
 	}
 
 	if ((f = fid_by_id(fid)) == NULL || f->fd == -1) {
+		*writeskip = 1;
 		np_error(hdr->tag, "invalid fid");
 		return;
 	}
 
 	if (!(f->iomode & O_WRONLY) &&
 	    !(f->iomode & O_RDWR)) {
+		*writeskip = 1;
 		np_error(hdr->tag, "fid not opened for writing");
 		return;
 	}
 
 	if (TYPE_OVERFLOW(off_t, off)) {
+		*writeskip = 1;
 		log_warnx("unexpected off_t size");
 		np_error(hdr->tag, "invalid offset");
 		return;
 	}
 
-	if ((r = pwrite(f->fd, data, len, off)) == -1)
-		np_errno(hdr->tag);
-	else
+	if ((r = pwrite(f->fd, data, len, off)) == -1) {
+		/*
+		 * The request failed and was already answered: skip
+		 * its continuations instead of writing them.
+		 */
+		*writeskip = 1;
+		np_errno(hdr->tag);
+		return;
+	}
+
+	if (count == len) {
 		np_write(hdr->tag, r);
+		return;
+	}
+
+	/*
+	 * The remaining count - len bytes will arrive with
+	 * IMSG_BUF_CONT messages: remember where to write them.
+	 */
+	*writefid = f;
+	*writepos = off + len;
+	*writeleft = count - len;
+	*writeskip = 0;
 }
 
+/*
+ * Write one IMSG_BUF_CONT chunk of a Twrite at *writepos and
+ * reply once the last announced byte has been written.
+ */
 static void
+twrite_cont(struct fid *f, off_t *writepos, size_t *writeleft, int *writeskip,
+    uint16_t tag, const uint8_t *data, size_t len)
+{
+	ssize_t r;
+
+	/* more data than the Twrite count announced: drop the client */
+	if (len > *writeleft) {
+		client_send_listener(IMSG_CLOSE, NULL, 0);
+		client_shutdown();
+		return;
+	}
+
+	if ((r = pwrite(f->fd, data, len, *writepos)) == -1) {
+		/*
+		 * The error is the reply to this Twrite: drop the
+		 * remaining chunks and clear writeleft so the next
+		 * non-continuation message is not rejected.
+		 */
+		*writeskip = 1;
+		*writeleft = 0;
+		np_errno(tag);
+		return;
+	}
+
+	*writeleft -= len;
+	*writepos += len;
+
+	/*
+	 * XXX: r only covers this last chunk; the Rwrite count
+	 * should be the total written, which isn't tracked here.
+	 */
+	if (*writeleft == 0)
+		np_write(tag, r);
+}
+
+static void
 tstat(struct np_msg_header *hdr, const uint8_t *data, size_t len)
 {
 	struct evbuffer	*evb;
@@ -1857,8 +1889,13 @@ tremove(struct np_msg_header *hdr, const uint8_t *data
 }
 
 static void
-handle_message(struct imsg *imsg, size_t len)
+handle_message(struct imsg *imsg, size_t len, int cont)
 {
+	static struct fid *writefid;
+	static off_t writepos;
+	static size_t writeleft;
+	static int writeskip;
+	static uint16_t writetag;
 	struct msg {
 		uint8_t	 type;
 		void	(*fn)(struct np_msg_header *, const uint8_t *, size_t);
@@ -1871,7 +1908,7 @@ handle_message(struct imsg *imsg, size_t len)
 		{Topen,		topen},
 		{Tcreate,	tcreate},
 		{Tread,		tread},
-		{Twrite,	twrite},
+		/* {Twrite,	twrite}, */
 		{Tstat,		tstat},
 		{Twstat,	twstat},
 		{Tremove,	tremove},
@@ -1884,6 +1921,42 @@ handle_message(struct imsg *imsg, size_t len)
 	hexdump("incoming packet", imsg->data, len);
 #endif
 
+	/*
+	 * Twrite is special and can be "continued" to allow writing
+	 * more than what the imsg framework would allow us to.
+	 */
+	if (writeleft > 0 && !cont) {
+		log_warnx("received a non continuation message when still "
+		    "missed %zu bytes to write", writeleft);
+		client_send_listener(IMSG_CLOSE, NULL, 0);
+		client_shutdown();
+		return;
+	}
+
+	if (cont) {
+		if (writeskip)
+			return;
+
+		if (writefid == NULL) {
+			log_warnx("received a continuation message without "
+			    "seeing a Twrite");
+			client_send_listener(IMSG_CLOSE, NULL, 0);
+			client_shutdown();
+			return;
+		}
+
+		log_debug("continuing...");
+		twrite_cont(writefid, &writepos, &writeleft, &writeskip,
+		    writetag, imsg->data, len);
+		return;
+	}
+
+	/* not a continuation: reset the Twrite continuation state */
+	writefid = NULL;
+	writepos = -1;
+	writeleft = 0;
+	writeskip = 0;
+
 	parse_message(imsg->data, len, &hdr, &data);
 	len -= HEADERSIZE;
 
@@ -1896,6 +1968,13 @@ handle_message(struct imsg *imsg, size_t len)
 		return;
 	}
 
+	if (hdr.type == Twrite) {
+		writetag = hdr.tag;
+		twrite(&hdr, data, len, &writefid, &writepos, &writeleft,
+		    &writeskip);
+		return;
+	}
+
 	for (i = 0; i < sizeof(msgs)/sizeof(msgs[0]); ++i) {
 		if (msgs[i].type != hdr.type)
 			continue;
blob - cd4d70bab61ef9d618e24fb7d443c305fc293d04
blob + f076107c77bf22a64cbe48841f2a93e082efa10a
--- kamid/kamid.h
+++ kamid/kamid.h
@@ -46,6 +46,7 @@ enum imsg_type {
 	IMSG_AUTH_TLS,		/* kd_auth_req */
 	IMSG_CONN_GONE,
 	IMSG_BUF,
+	IMSG_BUF_CONT,
 	IMSG_MSIZE,
 	IMSG_CLOSE,
 };
blob - ec2b18df206a2e778d1b1d7fa9417861adcdca45
blob + 03946221c35c7b1834f4f6cb43b38424245f44e8
--- kamid/listener.c
+++ kamid/listener.c
@@ -44,6 +44,8 @@
 #include "sandbox.h"
 #include "utils.h"
 
+#define IMSG_MAXSIZE (MAX_IMSGSIZE - IMSG_HEADER_SIZE)
+
 static struct kd_conf	*listener_conf;
 static struct imsgev	*iev_main;
 
@@ -57,6 +59,7 @@ struct client {
 	uint32_t		 lid;
 	uint32_t		 lflags;
 	uint32_t		 msize;
+	uint32_t		 left;
 	int			 fd;
 	struct tls		*ctx;
 	struct event		 event;
@@ -483,7 +486,7 @@ listener_dispatch_client(int fd, short event, void *d)
 
 			if (client->msize == 0)
 				fatal("IMSG_MSIZE got msize = 0");
-
+			log_debug("set msize to %"PRIu32, client->msize);
 			break;
 
 		case IMSG_CLOSE:
@@ -689,17 +692,33 @@ client_read(struct bufferevent *bev, void *d)
 {
 	struct client	*client = d;
 	struct evbuffer	*src = EVBUFFER_INPUT(bev);
-	uint32_t	 len;
+	size_t evlen;
+	uint32_t len;
 
 	for (;;) {
-		if (EVBUFFER_LENGTH(src) < 4)
+		evlen = EVBUFFER_LENGTH(src);
+
+		if (client->left != 0) {
+			/* wait to fill a whole imsg if possible */
+			if (client->left >= IMSG_MAXSIZE &&
+			    evlen < IMSG_MAXSIZE)
+				return;
+
+			len = MIN(client->left, evlen);
+			listener_imsg_compose_client(client, IMSG_BUF_CONT,
+			    client->id, EVBUFFER_DATA(src), len);
+			evbuffer_drain(src, len);
+			client->left -= len;
+			continue;
+		}
+
+		if (evlen < 4)
 			return;
 
 		memcpy(&len, EVBUFFER_DATA(src), sizeof(len));
 		len = le32toh(len);
 		log_debug("expecting a message %"PRIu32" bytes long "
-		    "(of wich %zu already read)",
-		    len, EVBUFFER_LENGTH(src));
+		    "(of wich %zu already read)", len, evlen);
 
 		if (len < HEADERSIZE) {
 			log_warnx("invalid message size %d (too low)", len);
@@ -714,7 +733,15 @@ client_read(struct bufferevent *bev, void *d)
 			return;
 		}
 
-		if (len > EVBUFFER_LENGTH(src))
+		if (len > IMSG_MAXSIZE && evlen >= len) {
+			listener_imsg_compose_client(client, IMSG_BUF,
+			    client->id, EVBUFFER_DATA(src), IMSG_MAXSIZE);
+			evbuffer_drain(src, IMSG_MAXSIZE);
+			client->left = len - IMSG_MAXSIZE;
+			continue;
+		}
+
+		if (len > evlen)
 			return;
 
 		listener_imsg_compose_client(client, IMSG_BUF, client->id,