Commit Diff


commit - 2dfdd11a40cc7c2f468c0b614f7d0aaffa1c6074
commit + 8f7253b29f275cfc3fd6111a54ce1cdd6622d1ee
blob - c45a5cb11b0b375cb1f32e4716c74b48d44058fa
blob + 2d4b1af1e1145afbda93fa3e94522a2bd456f5a6
--- kamiftp/ftp.c
+++ kamiftp/ftp.c
@@ -64,9 +64,7 @@ char		*host;
 char		*port;
 
 /* state */
-struct tls_config	*tlsconf;
-struct tls		*ctx;
-int			 sock;
+FILE			*fp;	/* server connection: plain socket or TLS stream */
 struct evbuffer		*buf;
 struct evbuffer		*dirbuf;
 uint32_t		 msize;
@@ -165,6 +163,59 @@ spawn(const char *argv0, ...)
 	}
 }
 
+/*
+ * funopen(3) callbacks that back the connection stream with a libtls
+ * context.  Each retries while libtls reports TLS_WANT_POLLIN or
+ * TLS_WANT_POLLOUT, then returns the libtls result (-1 on failure).
+ */
+static int
+stdio_tls_write(void *arg, const char *buf, int len)
+{
+	struct tls	*ctx = arg;
+	ssize_t		 ret;
+
+	do {
+		ret = tls_write(ctx, buf, len);
+	} while (ret == TLS_WANT_POLLIN || ret == TLS_WANT_POLLOUT);
+
+	/* tls_error() describes the failure; errno may be unrelated */
+	if (ret == -1)
+		warnx("tls_write: %s", tls_error(ctx));
+
+	return ret;
+}
+
+static int
+stdio_tls_read(void *arg, char *buf, int len)
+{
+	struct tls	*ctx = arg;
+	ssize_t		 ret;
+
+	do {
+		ret = tls_read(ctx, buf, len);
+	} while (ret == TLS_WANT_POLLIN || ret == TLS_WANT_POLLOUT);
+
+	if (ret == -1)
+		warnx("tls_read: %s", tls_error(ctx));
+
+	return ret;
+}
+
+static int
+stdio_tls_close(void *arg)
+{
+	struct tls	*ctx = arg;
+	int		 ret;
+
+	do {
+		ret = tls_close(ctx);
+	} while (ret == TLS_WANT_POLLIN || ret == TLS_WANT_POLLOUT);
+
+	/* the stream is the only holder of ctx: release it here */
+	tls_free(ctx);
+	return ret;
+}
+
 static void
 tty_resized(int signo)
 {
@@ -194,30 +237,23 @@ nextfid(void)
 static void
 do_send(void)
 {
-	const void	*buf;
-	size_t		 nbytes;
-	ssize_t		 r;
+	size_t		 r;
 
 	if (xdump)
 		hexdump("outgoing message", EVBUFFER_DATA(evb),
 		    EVBUFFER_LENGTH(evb));
 
 	while (EVBUFFER_LENGTH(evb) != 0) {
-		buf = EVBUFFER_DATA(evb);
-		nbytes = EVBUFFER_LENGTH(evb);
-
-		if (ctx == NULL) {
-			r = write(sock, buf, nbytes);
-			if (r == 0 || r == -1)
-				errx(1, "EOF");
-		} else {
-			r = tls_write(ctx, buf, nbytes);
-			if (r == TLS_WANT_POLLIN || r == TLS_WANT_POLLOUT)
-				continue;
-			if (r == -1)
-				errx(1, "tls: %s", tls_error(ctx));
-		}
-
+		r = fwrite(EVBUFFER_DATA(evb), 1, EVBUFFER_LENGTH(evb), fp);
+		if (r == 0)
+			fatalx("unexpected EOF");
 		evbuffer_drain(evb, r);
 	}
+
+	/*
+	 * ISO C requires a flush before an update stream switches from
+	 * output to input; doing it here also reports write errors now.
+	 */
+	if (fflush(fp) == EOF)
+		fatalx("unexpected EOF");
 }
@@ -225,33 +254,21 @@ do_send(void)
 static void
 mustread(void *d, size_t len)
 {
-	ssize_t r;
+	size_t r;
 
-	while (len != 0) {
-		if (ctx == NULL) {
-			r = read(sock, d, len);
-			if (r == 0 || r == -1)
-				errx(1, "EOF");
-		} else {
-			r = tls_read(ctx, d, len);
-			if (r == TLS_WANT_POLLIN || r == TLS_WANT_POLLOUT)
-				continue;
-			if (r == -1)
-				errx(1, "tls: %s", tls_error(ctx));
-		}
-
-		d += r;
-		len -= r;
-	}
+	/* fread(3) returns a short count only on EOF or a read error */
+	r = fread(d, 1, len, fp);
+	if (r != len)
+		errx(1, "unexpected EOF");
 }
 
 static void
 recv_msg(void)
 {
 	uint32_t	len, l;
 	char		tmp[BUFSIZ];
 
 	mustread(&len, sizeof(len));
 	len = le32toh(len);
 	if (len < HEADERSIZE)
 		errx(1, "read message of invalid length %d", len);
@@ -848,49 +871,17 @@ woc_file(int fd, const char *prompt, const char *path)
 
 	send_fid(nfid, n, KOTRUNC, fd, prompt);
 	return 0;
-}
-
-static void
-do_tls_connect(const char *host, const char *port)
-{
-	int handshake;
-
-	if ((tlsconf = tls_config_new()) == NULL)
-		fatalx("tls_config_new");
-	tls_config_insecure_noverifycert(tlsconf);
-	tls_config_insecure_noverifyname(tlsconf);
-
-	if (keypath == NULL)
-		keypath = crtpath;
-
-	if (tls_config_set_keypair_file(tlsconf, crtpath, keypath) == -1)
-		fatalx("can't load certs (%s, %s)", crtpath, keypath);
-
-	if ((ctx = tls_client()) == NULL)
-		fatal("tls_client");
-	if (tls_configure(ctx, tlsconf) == -1)
-		fatalx("tls_configure: %s", tls_error(ctx));
-
-	if (tls_connect(ctx, host, port) == -1)
-		fatalx("can't connect to %s:%s: %s", host, port,
-		    tls_error(ctx));
-
-	for (handshake = 0; !handshake;) {
-		switch (tls_handshake(ctx)) {
-		case -1:
-			fatalx("tls_handshake: %s", tls_error(ctx));
-		case 0:
-			handshake = 1;
-			break;
-		}
-	}
 }
 
-static void
-do_ctxt_connect(const char *host, const char *port)
+/*
+ * Resolve host:port and return a connected socket.  Exits via err(3)
+ * if no socket could be connected.
+ */
+static int
+dial(const char *host, const char *port)
 {
 	struct addrinfo hints, *res, *res0;
-	int error, saved_errno;
+	int sock, error, saved_errno;
 	const char *cause = NULL;
 
 	memset(&hints, 0, sizeof(hints));
@@ -920,10 +911,11 @@ do_ctxt_connect(const char *host, const char *port)
 
 		break;
 	}
 
 	if (sock == -1)
 		err(1, "%s", cause);
 	freeaddrinfo(res0);
+	return sock;
 }
 
 static void
@@ -953,11 +940,53 @@ do_connect(const char *connspec, const char *path)
 	printf("connecting to %s:%s...", host, port);
 	fflush(stdout);
 
-	if (tls)
-		do_tls_connect(host, port);
-	else
-		do_ctxt_connect(host, port);
+	if (tls) {
+		struct tls_config	*conf;
+		struct tls		*ctx;
+		int			 r;
 
+		if ((conf = tls_config_new()) == NULL)
+			fatalx("failed to create TLS config");
+		tls_config_insecure_noverifycert(conf);
+		tls_config_insecure_noverifyname(conf);
+
+		if (keypath == NULL)
+			keypath = crtpath;
+
+		if (tls_config_set_keypair_file(conf, crtpath, keypath) == -1)
+			fatalx("failed to load certs: (%s, %s)", crtpath,
+			    keypath);
+
+		if ((ctx = tls_client()) == NULL)
+			fatalx("failed to create TLS client");
+		if (tls_configure(ctx, conf) == -1)
+			fatalx("failed to configure TLS client: %s",
+			    tls_error(ctx));
+		tls_config_free(conf);
+
+		if (tls_connect(ctx, host, port) == -1)
+			fatalx("failed to connect to %s:%s: %s", host, port,
+			    tls_error(ctx));
+
+		do {
+			r = tls_handshake(ctx);
+		} while (r == TLS_WANT_POLLIN || r == TLS_WANT_POLLOUT);
+		if (r == -1)
+			fatalx("failed TLS handshake: %s", tls_error(ctx));
+
+		/* all further I/O goes through fp via the stdio_tls_* callbacks */
+		fp = funopen(ctx, stdio_tls_read, stdio_tls_write, NULL,
+		    stdio_tls_close);
+		if (fp == NULL)
+			fatal("funopen");
+	} else {
+		int fd;
+
+		fd = dial(host, port);
+		if ((fp = fdopen(fd, "r+")) == NULL)
+			fatal("fdopen");
+	}
+
 	printf(" done!\n");
 
 	do_version();
@@ -1056,6 +1080,7 @@ static void
 cmd_bye(int argc, const char **argv)
 {
 	log_warnx("bye\n");
+	fclose(fp);	/* flush and close the connection stream before exiting */
 	exit(0);
 }
 
@@ -1658,4 +1683,5 @@ main(int argc, char **argv)
 	}
 
 	printf("\n");
+	fclose(fp);	/* flush and close the connection stream */
 }