Commit Diff


commit - 5817aa44982898c5130ee119f4a9964e00741401
commit + 3f03529f911621de2dedee358e728d5ac49fbb74
blob - c218227956175b3d089238ca1934f43f54e80105
blob + 31864c7907dbaf020bfd191bfa18a0c983b4dd75
--- ftp.c
+++ ftp.c
@@ -20,6 +20,7 @@
 #include <sys/socket.h>
 
 #include <assert.h>
+#include <errno.h>
 #include <netdb.h>
 #include <limits.h>
 #include <stdio.h>
@@ -112,19 +113,34 @@ usage(int ret)
 static void
 do_send(void)
 {
-	ssize_t r;
+	const void	*buf;
+	size_t		 nbytes;
+	ssize_t		 r;
 
+	/* Write out all of evb; any failure is fatal. */
 	while (EVBUFFER_LENGTH(evb) != 0) {
-		r = tls_write(ctx, EVBUFFER_DATA(evb), EVBUFFER_LENGTH(evb));
-		switch (r) {
-		case TLS_WANT_POLLIN:
-		case TLS_WANT_POLLOUT:
-			continue;
-		case -1:
-			errx(1, "tls: %s", tls_error(ctx));
-		default:
-			evbuffer_drain(evb, r);
+		buf = EVBUFFER_DATA(evb);
+		nbytes = EVBUFFER_LENGTH(evb);
+
+		if (ctx == NULL) {
+			r = write(sock, buf, nbytes);
+			if (r == -1)
+				err(1, "write");
+		} else {
+			r = tls_write(ctx, buf, nbytes);
+			/* TLS needs more I/O before it can progress: retry. */
+			if (r == TLS_WANT_POLLIN || r == TLS_WANT_POLLOUT)
+				continue;
+			if (r == -1)
+				errx(1, "tls: %s", tls_error(ctx));
 		}
+
+		/* A 0 return drains nothing and would spin forever. */
+		if (r == 0)
+			errx(1, "EOF");
+
+		/* Drop only what was sent; a short write loops again. */
+		evbuffer_drain(evb, r);
 	}
 }
 
@@ -134,24 +143,32 @@ mustread(void *d, size_t len)
 	ssize_t r;
 
 	while (len != 0) {
-		switch (r = tls_read(ctx, d, len)) {
-		case TLS_WANT_POLLIN:
-		case TLS_WANT_POLLOUT:
-			continue;
-		case -1:
-			errx(1, "tls: %s", tls_error(ctx));
-		default:
-			d += r;
-			len -= r;
+		if (ctx == NULL) {
+			r = read(sock, d, len);
+			if (r == -1)
+				err(1, "read");
+		} else {
+			r = tls_read(ctx, d, len);
+			if (r == TLS_WANT_POLLIN || r == TLS_WANT_POLLOUT)
+				continue;
+			if (r == -1)
+				errx(1, "tls: %s", tls_error(ctx));
 		}
+
+		/* A 0 return is EOF for both read(2) and tls_read(3). */
+		if (r == 0)
+			errx(1, "EOF");
+
+		/* Short reads are normal: advance past what arrived. */
+		d = (char *)d + r;
+		len -= r;
 	}
 }
 
 static void
 recv_msg(void)
 {
-	uint32_t	len;
-	ssize_t		r;
+	uint32_t	len, l;
 	char		tmp[BUFSIZ];
 
 	mustread(&len, sizeof(len));
@@ -162,16 +174,10 @@ recv_msg(void)
 	len -= 4; /* skip the length just read */
 
 	while (len != 0) {
-		switch (r = tls_read(ctx, tmp, sizeof(tmp))) {
-		case TLS_WANT_POLLIN:
-		case TLS_WANT_POLLOUT:
-			continue;
-		case -1:
-			errx(1, "tls: %s", tls_error(ctx));
-		default:
-			len -= r;
-			evbuffer_add(buf, tmp, r);
-		}
+		l = MIN(len, sizeof(tmp));	/* don't overrun the message */
+		mustread(tmp, l);		/* returns with all l bytes read */
+		len -= l;
+		evbuffer_add(buf, tmp, l);
 	}
 }
 
@@ -360,22 +366,10 @@ dup_fid(int fid, int nfid)
 }
 
 static void
-do_connect(const char *connspec, const char *path)
+do_tls_connect(const char *host, const char *port)
 {
 	int handshake;
-	char *host, *colon;
-	const char *port;
 
-	host = xstrdup(connspec);
-	if ((colon = strchr(host, ':')) != NULL) {
-		*colon = '\0';
-		port = ++colon;
-	} else
-		port = "1337";
-
-	if (!tls)
-		fatalx("non-tls mode is not supported");
-
 	if ((tlsconf = tls_config_new()) == NULL)
 		fatalx("tls_config_new");
 	tls_config_insecure_noverifycert(tlsconf);
@@ -388,9 +382,6 @@ do_connect(const char *connspec, const char *path)
 	if (tls_configure(ctx, tlsconf) == -1)
 		fatalx("tls_configure: %s", tls_error(ctx));
 
-	printf("connecting to %s:%s...", host, port);
-	fflush(stdout);
-
 	if (tls_connect(ctx, host, port) == -1)
 		fatalx("can't connect to %s:%s: %s", host, port,
 		    tls_error(ctx));
@@ -404,7 +395,75 @@ do_connect(const char *connspec, const char *path)
 			break;
 		}
 	}
+}
 
+/*
+ * Open a plaintext SOCK_STREAM connection to host:port and store the
+ * descriptor in sock.  Every address getaddrinfo(3) returns is tried
+ * in order; the process exits if resolution fails or none connects.
+ */
+static void
+do_ctxt_connect(const char *host, const char *port)
+{
+	struct addrinfo hints, *res, *res0;
+	int error, saved_errno;
+	const char *cause = NULL;
+
+	memset(&hints, 0, sizeof(hints));
+	hints.ai_family = AF_UNSPEC;
+	hints.ai_socktype = SOCK_STREAM;
+	error = getaddrinfo(host, port, &hints, &res0);
+	if (error)
+		errx(1, "can't resolve %s:%s: %s", host, port,
+		    gai_strerror(error));
+
+	sock = -1;
+	for (res = res0; res != NULL; res = res->ai_next) {
+		sock = socket(res->ai_family, res->ai_socktype,
+		    res->ai_protocol);
+		if (sock == -1) {
+			cause = "socket";
+			continue;
+		}
+
+		if (connect(sock, res->ai_addr, res->ai_addrlen) == -1) {
+			cause = "connect";
+			saved_errno = errno;
+			close(sock);
+			errno = saved_errno;
+			sock = -1;
+			continue;
+		}
+
+		break;
+	}
+
+	if (sock == -1)
+		err(1, "%s", cause);
+	freeaddrinfo(res0);
+}
+
+static void
+do_connect(const char *connspec, const char *path)
+{
+	char *host, *colon;
+	const char *port;
+
+	host = xstrdup(connspec);
+	if ((colon = strchr(host, ':')) != NULL) {
+		*colon = '\0';
+		port = ++colon;
+	} else
+		port = "1337";
+
+	printf("connecting to %s:%s...", host, port);
+	fflush(stdout);
+
+	if (tls)
+		do_tls_connect(host, port);
+	else
+		do_ctxt_connect(host, port);
+
 	printf(" done!\n");
 
 	do_version();