Blob


1 /*
2 * Copyright (c) 2021 Omar Polo <op@omarpolo.com>
3 *
4 * Permission to use, copy, modify, and distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
7 *
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15 */
17 #include "compat.h"
19 #include <sys/types.h>
20 #include <sys/socket.h>
22 #include <netdb.h>
24 #include <assert.h>
25 #include <endian.h>
26 #include <errno.h>
27 #include <fcntl.h>
28 #include <inttypes.h>
29 #include <signal.h>
30 #include <stdio.h>
31 #include <stdlib.h>
32 #include <string.h>
33 #include <syslog.h>
34 #include <tls.h>
35 #include <unistd.h>
37 #include "kamid.h"
38 #include "log.h"
39 #include "utils.h"
/* set to 1 to hexdump every packet sent and received */
#define DEBUG_PACKETS 0

#define PROMPT "=% "

/*
 * Command-line flags (filled in by main()).
 */
int		 verbose;	/* -v: verbose logging */
int		 tls;		/* -c: connect using TLS */
const char	*keypath;	/* -K: client key file */
const char	*crtpath;	/* -C: client certificate file */
const char	*host;		/* -H: server host */
const char	*port;		/* -P: server port */

/*
 * Runtime state shared between main() and the libevent callbacks.
 */
struct tls_config	*tlsconf;
struct tls		*ctx;
struct bufferevent	*bev, *inbev;	/* server connection, stdin */
58 static void ATTR_DEAD usage(int);
60 static void sig_handler(int, short, void *);
62 static int openconn(void);
63 static void mark_nonblock(int);
65 static void tls_readcb(int, short, void *);
66 static void tls_writecb(int, short, void *);
68 static void client_read(struct bufferevent *, void *);
69 static void client_write(struct bufferevent *, void *);
70 static void client_error(struct bufferevent *, short, void *);
72 static void repl_read(struct bufferevent *, void *);
73 static void repl_error(struct bufferevent *, short, void *);
74 static void write_hdr(uint32_t, uint8_t, uint16_t);
75 static void write_hdr_auto(uint32_t, uint8_t);
76 static void write_str(uint16_t, const char *);
77 static void write_str_auto(const char *);
78 static void write_fid(uint32_t);
79 static void write_tag(uint16_t);
81 static void excmd_version(const char **, int);
82 static void excmd_attach(const char **, int);
83 static void excmd_clunk(const char **, int);
84 static void excmd_flush(const char **, int);
85 static void excmd_walk(const char ** , int);
86 static void excmd(const char **, int);
88 static const char *pp_qid_type(uint8_t);
89 static void pp_qid(const uint8_t *, uint32_t);
90 static void pp_msg(uint32_t, uint8_t, uint16_t, const uint8_t *);
91 static void handle_9p(const uint8_t *, size_t);
92 static void clr(void);
93 static void prompt(void);
95 static void ATTR_DEAD
96 usage(int ret)
97 {
98 fprintf(stderr,
99 "usage: %s [-chv] [-C crt] [-K key] [-H host] [-P port]\n",
100 getprogname());
101 fprintf(stderr, PACKAGE_NAME " suite version " PACKAGE_VERSION "\n");
102 exit(ret);
/*
 * libevent signal callback.  libevent delivers the signal from inside
 * its event loop, so the usual async-signal-safety rules don't apply.
 * SIGINT and SIGTERM stop the event loop; any other signal is a bug.
 */
static void
sig_handler(int sig, short event, void *d)
{
	if (sig == SIGINT || sig == SIGTERM) {
		clr();
		log_warnx("Shutting down...");
		event_loopbreak();
		return;
	}

	fatalx("unexpected signal %d", sig);
}
/*
 * Resolve the global host/port pair and connect a stream socket to the
 * first address that accepts it.  Returns the connected descriptor, or
 * -1 after printing a diagnostic on stderr.
 */
static int
openconn(void)
{
	struct addrinfo	 hints, *ai, *ailist;
	const char	*why = NULL;
	int		 fd = -1, gaierr, serrno;

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;

	gaierr = getaddrinfo(host, port, &hints, &ailist);
	if (gaierr != 0) {
		warnx("%s", gai_strerror(gaierr));
		return -1;
	}

	/* stop at the first address we manage to connect to */
	for (ai = ailist; ai != NULL && fd == -1; ai = ai->ai_next) {
		fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
		if (fd == -1) {
			why = "socket";
			continue;
		}

		if (connect(fd, ai->ai_addr, ai->ai_addrlen) == -1) {
			why = "connect";
			/* close(2) may clobber errno; keep connect's for warn() */
			serrno = errno;
			close(fd);
			errno = serrno;
			fd = -1;
		}
	}

	freeaddrinfo(ailist);

	if (fd == -1)
		warn("%s", why);

	return fd;
}
/* Put `fd' in non-blocking mode; any fcntl(2) failure is fatal. */
static void
mark_nonblock(int fd)
{
	int	fl;

	fl = fcntl(fd, F_GETFL);
	if (fl == -1)
		fatal("fcntl(F_GETFL)");

	fl |= O_NONBLOCK;
	if (fcntl(fd, F_SETFL, fl) == -1)
		fatal("fcntl(F_SETFL)");
}
182 static void
183 tls_readcb(int fd, short event, void *d)
185 struct bufferevent *bufev = d;
186 char buf[IBUF_READ_SIZE];
187 int what = EVBUFFER_READ;
188 int howmuch = IBUF_READ_SIZE;
189 ssize_t ret;
190 size_t len;
192 if (event == EV_TIMEOUT) {
193 what |= EVBUFFER_TIMEOUT;
194 goto err;
197 if (bufev->wm_read.high != 0)
198 howmuch = MIN(sizeof(buf), bufev->wm_read.high);
200 switch (ret = tls_read(ctx, buf, howmuch)) {
201 case TLS_WANT_POLLIN:
202 case TLS_WANT_POLLOUT:
203 goto retry;
204 case -1:
205 what |= EVBUFFER_ERROR;
206 goto err;
208 len = ret;
210 if (len == 0) {
211 what |= EVBUFFER_EOF;
212 goto err;
215 if (evbuffer_add(bufev->input, buf, len) == -1) {
216 what |= EVBUFFER_ERROR;
217 goto err;
220 event_add(&bufev->ev_read, NULL);
222 len = EVBUFFER_LENGTH(bufev->input);
223 if (bufev->wm_read.low != 0 && len < bufev->wm_read.low)
224 return;
225 if (bufev->readcb != NULL)
226 (*bufev->readcb)(bufev, bufev->cbarg);
227 return;
229 retry:
230 event_add(&bufev->ev_read, NULL);
231 return;
233 err:
234 (*bufev->errorcb)(bufev, what, bufev->cbarg);
237 static void
238 tls_writecb(int fd, short event, void *d)
240 struct bufferevent *bufev = d;
241 ssize_t ret;
242 size_t len;
243 short what = EVBUFFER_WRITE;
244 void *data;
246 if (event == EV_TIMEOUT) {
247 what |= EVBUFFER_TIMEOUT;
248 goto err;
251 len = EVBUFFER_LENGTH(bufev->output);
252 if (len != 0) {
253 data = EVBUFFER_DATA(bufev->output);
255 #if DEBUG_PACKETS
256 hexdump("outgoing msg", data, len);
257 #endif
259 switch (ret = tls_write(ctx, data, len)) {
260 case TLS_WANT_POLLIN:
261 case TLS_WANT_POLLOUT:
262 goto retry;
263 case -1:
264 what |= EVBUFFER_ERROR;
265 goto err;
267 evbuffer_drain(bufev->output, ret);
270 if (EVBUFFER_LENGTH(bufev->output) != 0)
271 event_add(&bufev->ev_write, NULL);
273 if (bufev->writecb != NULL &&
274 EVBUFFER_LENGTH(bufev->output) <= bufev->wm_write.low)
275 (*bufev->writecb)(bufev, bufev->cbarg);
276 return;
278 retry:
279 event_add(&bufev->ev_write, NULL);
280 return;
281 err:
282 (*bufev->errorcb)(bufev, what, bufev->cbarg);
285 static void
286 client_read(struct bufferevent *bev, void *d)
288 struct evbuffer *src = EVBUFFER_INPUT(bev);
289 uint32_t len;
290 uint8_t *data;
292 for (;;) {
293 if (EVBUFFER_LENGTH(src) < sizeof(len))
294 return;
296 data = EVBUFFER_DATA(src);
298 memcpy(&len, data, sizeof(len));
299 len = le32toh(len);
301 if (len < HEADERSIZE)
302 fatal("incoming message is too small! (%d bytes)",
303 len);
305 if (len > EVBUFFER_LENGTH(src))
306 return;
308 #if DEBUG_PACKETS
309 hexdump("incoming msg", data, len);
310 #endif
312 handle_9p(data, len);
313 evbuffer_drain(src, len);
/* Write callback for the server connection: nothing to do once output drains. */
static void
client_write(struct bufferevent *bev, void *data)
{
	(void)bev;
	(void)data;
}
323 static void
324 client_error(struct bufferevent *bev, short err, void *data)
326 if (err & EVBUFFER_ERROR)
327 fatal("buffer event error");
329 if (err & EVBUFFER_EOF) {
330 clr();
331 log_info("EOF");
332 event_loopbreak();
333 return;
336 clr();
337 log_warnx("unknown event error");
338 event_loopbreak();
341 static void
342 repl_read(struct bufferevent *bev, void *d)
344 size_t len;
345 int argc;
346 const char *argv[10], **ap;
347 char *line;
349 line = evbuffer_readln(bev->input, &len, EVBUFFER_EOL_LF);
350 if (line == NULL)
351 return;
353 for (argc = 0, ap = argv; ap < &argv[9] &&
354 (*ap = strsep(&line, " \t")) != NULL;) {
355 if (**ap != '\0')
356 ap++, argc++;
359 clr();
360 excmd(argv, argc);
361 prompt();
363 free(line);
/*
 * Error callback for stdin: whatever condition is reported here is
 * treated as fatal (presumably this includes EOF on stdin — confirm
 * against libevent's bufferevent semantics).
 */
static void
repl_error(struct bufferevent *bev, short error, void *d)
{
	(void)bev;
	(void)error;
	(void)d;

	fatalx("an error occurred");
}
372 static void
373 write_hdr(uint32_t len, uint8_t type, uint16_t tag)
375 len += HEADERSIZE;
377 log_debug("enqueuing a packet; len=%"PRIu32" type=%d[%s] tag=%d",
378 len, type, pp_msg_type(type), tag);
380 len = htole32(len);
381 /* type is one byte, no endiannes issues */
382 tag = htole16(tag);
384 bufferevent_write(bev, &len, sizeof(len));
385 bufferevent_write(bev, &type, sizeof(type));
386 bufferevent_write(bev, &tag, sizeof(tag));
389 static void
390 write_hdr_auto(uint32_t len, uint8_t type)
392 static uint16_t tag = 0;
394 if (++tag == NOTAG)
395 ++tag;
397 write_hdr(len, type, tag);
400 static void
401 write_str(uint16_t len, const char *str)
403 uint16_t l = len;
405 len = htole16(len);
406 bufferevent_write(bev, &len, sizeof(len));
407 bufferevent_write(bev, str, l);
/*
 * write_str() with the length taken from strlen(3).  The length is
 * narrowed to 16 bits, so longer strings are truncated on the wire.
 */
static void
write_str_auto(const char *str)
{
	size_t	n = strlen(str);

	write_str((uint16_t)n, str);
}
416 static void
417 write_fid(uint32_t fid)
419 fid = htole32(fid);
420 bufferevent_write(bev, &fid, sizeof(fid));
423 static void
424 write_tag(uint16_t tag)
426 tag = htole16(tag);
427 bufferevent_write(bev, &tag, sizeof(tag));
430 /* version [version-str] */
431 static void
432 excmd_version(const char **argv, int argc)
434 uint32_t len, msize;
435 uint16_t sl;
436 const char *s;
438 s = VERSION9P;
439 if (argc == 2)
440 s = argv[1];
442 sl = strlen(s);
444 /* msize[4] version[s] */
445 len = 4 + sizeof(sl) + sl;
446 write_hdr(len, Tversion, NOTAG);
448 msize = htole32(MSIZE9P);
449 bufferevent_write(bev, &msize, sizeof(msize));
451 write_str(sl, s);
454 /* attach fid uname aname */
455 static void
456 excmd_attach(const char **argv, int argc)
458 uint32_t len, fid;
459 uint16_t sl, tl;
460 const char *s, *t, *errstr;
462 if (argc != 4)
463 goto usage;
465 fid = strtonum(argv[1], 0, UINT32_MAX, &errstr);
466 if (errstr != NULL) {
467 log_warnx("fid is %s: %s", errstr, argv[1]);
468 return;
471 s = argv[2];
472 sl = strlen(s);
473 t = argv[3];
474 tl = strlen(t);
476 /* fid[4] afid[4] uname[s] aname[s] */
477 len = 4 + 4 + sizeof(sl) + sl + sizeof(tl) + tl;
478 write_hdr_auto(len, Tattach);
479 write_fid(fid);
480 write_fid(NOFID);
481 write_str(sl, s);
482 write_str(tl, t);
484 return;
486 usage:
487 log_warnx("usage: attach fid uname aname");
490 /* clunk fid */
491 static void
492 excmd_clunk(const char **argv, int argc)
494 uint32_t len, fid;
495 const char *errstr;
497 if (argc != 2)
498 goto usage;
500 fid = strtonum(argv[1], 0, UINT32_MAX, &errstr);
501 if (errstr != NULL) {
502 log_warnx("fid is %s: %s", errstr, argv[1]);
503 return;
506 /* fid[4] */
507 len = sizeof(fid);
508 write_hdr_auto(len, Tclunk);
509 write_fid(fid);
510 return;
512 usage:
513 log_warnx("usage: clunk fid");
516 /* flush oldtag */
517 static void
518 excmd_flush(const char **argv, int argc)
520 uint32_t len;
521 uint16_t oldtag;
522 const char *errstr;
524 if (argc != 2)
525 goto usage;
527 oldtag = strtonum(argv[1], 0, UINT16_MAX, &errstr);
528 if (errstr != NULL) {
529 log_warnx("oldtag is %s: %s", errstr, argv[1]);
530 return;
533 /* oldtag[2] */
534 len = sizeof(oldtag);
535 write_hdr_auto(len, Tflush);
536 write_tag(oldtag);
537 return;
539 usage:
540 log_warnx("usage: flush oldtag");
543 /* walk fid newfid wnames... */
544 static void
545 excmd_walk(const char **argv, int argc)
547 int i;
548 uint32_t len, fid, newfid;
549 const char *errstr;
551 if (argc < 3)
552 goto usage;
554 /* fid[4] newfid[4] nwname[2] nwname*(wname[s]) */
556 /* two bytes for wnames count */
557 len = sizeof(fid) + sizeof(newfid) + 2;
558 for (i = 3; i < argc; ++i)
559 len += 2 + strlen(argv[i]);
561 fid = strtonum(argv[1], 0, UINT32_MAX, &errstr);
562 if (errstr != NULL) {
563 log_warnx("fid is %s: %s", errstr, argv[1]);
564 return;
567 newfid = strtonum(argv[2], 0, UINT32_MAX, &errstr);
568 if (errstr != NULL) {
569 log_warnx("newfid is %s: %s", errstr, argv[1]);
570 return;
573 write_hdr_auto(len, Twalk);
574 write_fid(fid);
575 write_fid(newfid);
576 write_tag(argc - 3);
577 for (i = 3; i < argc; ++i)
578 write_str_auto(argv[i]);
580 return;
582 usage:
583 log_warnx("usage: walk fid newfid wnames...");
586 static void
587 excmd(const char **argv, int argc)
589 struct cmd {
590 const char *name;
591 void (*fn)(const char **, int);
592 } cmds[] = {
593 {"version", excmd_version},
594 {"attach", excmd_attach},
595 {"clunk", excmd_clunk},
596 {"flush", excmd_flush},
597 {"walk", excmd_walk},
598 };
599 size_t i;
601 if (argc == 0)
602 return;
604 for (i = 0; i < sizeof(cmds)/sizeof(cmds[0]); ++i) {
605 if (!strcmp(cmds[i].name, argv[0])) {
606 cmds[i].fn(argv, argc);
607 return;
611 log_warnx("Unknown command %s", *argv);
614 static const char *
615 pp_qid_type(uint8_t type)
617 switch (type) {
618 case QTDIR: return "dir";
619 case QTAPPEND: return "append-only";
620 case QTEXCL: return "exclusive";
621 case QTMOUNT: return "mounted-channel";
622 case QTAUTH: return "authentication";
623 case QTTMP: return "non-backed-up";
624 case QTSYMLINK: return "symlink";
625 case QTFILE: return "file";
628 return "unknown";
/*
 * Print a qid encoded as type[1] version[4] path[8] (little-endian,
 * 13 bytes).  `len' is the number of bytes available at `d'; if it is
 * too short, "invalid" is printed instead.
 */
static void
pp_qid(const uint8_t *d, uint32_t len)
{
	uint64_t	path;
	uint32_t	vers;
	uint8_t		type;

	if (len < 13) {
		printf("invalid");
		return;
	}

	type = *d++;

	memcpy(&vers, d, sizeof(vers));
	d += sizeof(vers);
	/* 4-byte field: a 64-bit conversion would yield 0 on big-endian */
	vers = le32toh(vers);

	memcpy(&path, d, sizeof(path));
	path = le64toh(path);

	printf("qid{path=%"PRIu64" version=%"PRIu32" type=0x%x\"%s\"}",
	    path, vers, type, pp_qid_type(type));
}
657 static void
658 pp_msg(uint32_t len, uint8_t type, uint16_t tag, const uint8_t *d)
660 uint32_t msize;
661 uint16_t slen;
663 printf("len=%"PRIu32" type=%d[%s] tag=0x%x[%d] ", len,
664 type, pp_msg_type(type), tag, tag);
666 len -= HEADERSIZE;
668 switch (type) {
669 case Rversion:
670 if (len < 6) {
671 printf("invalid: not enough space for msize "
672 "and version provided.");
673 break;
676 memcpy(&msize, d, sizeof(msize));
677 d += sizeof(msize);
678 len -= sizeof(msize);
679 msize = le32toh(msize);
681 memcpy(&slen, d, sizeof(slen));
682 d += sizeof(slen);
683 len -= sizeof(slen);
684 slen = le16toh(slen);
686 if (len != slen) {
687 printf("invalid: version string length doesn't "
688 "match. Got %d; want %d", slen, len);
689 break;
692 printf("msize=%"PRIu32" version[%"PRIu16"]=\"",
693 msize, slen);
694 fwrite(d, 1, slen, stdout);
695 printf("\"");
697 break;
699 case Rattach:
700 pp_qid(d, len);
701 break;
703 case Rclunk:
704 if (len != 0)
705 printf("invalid Rclunk: %"PRIu32" extra bytes", len);
706 break;
708 case Rflush:
709 if (len != 0)
710 printf("invalid Rflush: %"PRIu32" extra bytes", len);
711 break;
713 case Rwalk:
714 if (len < 2) {
715 printf("invaild Rwalk: less than two bytes (%d)",
716 (int)len);
717 break;
720 memcpy(&slen, d, sizeof(slen));
721 d += sizeof(slen);
722 len -= sizeof(slen);
723 slen = le16toh(slen);
725 if (len != QIDSIZE * slen) {
726 printf("invalid Rwalk: wanted %d bytes for %d qids "
727 "but got %"PRIu32" bytes instead",
728 QIDSIZE*slen, slen, len);
729 break;
732 printf("nwqid=%"PRIu16, slen);
734 for (; slen != 0; slen--) {
735 printf(" ");
736 pp_qid(d, len);
737 d += QIDSIZE;
738 len -= QIDSIZE;
741 break;
743 case Rerror:
744 memcpy(&slen, d, sizeof(slen));
745 d += sizeof(slen);
746 len -= sizeof(slen);
747 slen = le16toh(slen);
749 if (slen != len) {
750 printf("invalid: error string length doesn't "
751 "match. Got %d; want %d", slen, len);
752 break;
755 printf("error=\"");
756 fwrite(d, 1, slen, stdout);
757 printf("\"");
759 break;
761 default:
762 printf("unknown command type");
765 printf("\n");
768 static void
769 handle_9p(const uint8_t *data, size_t size)
771 uint32_t len;
772 uint16_t tag;
773 uint8_t type;
775 assert(size >= HEADERSIZE);
777 memcpy(&len, data, sizeof(len));
778 data += sizeof(len);
780 memcpy(&type, data, sizeof(type));
781 data += sizeof(type);
783 memcpy(&tag, data, sizeof(tag));
784 data += sizeof(tag);
786 len = le32toh(len);
787 /* type is one byte long, no endianness issues */
788 tag = le16toh(tag);
790 clr();
791 pp_msg(len, type, tag, data);
792 prompt();
/*
 * Move the cursor back to column 0 so that the next output starts over
 * the current prompt line; flushed immediately.
 */
static void
clr(void)
{
	fputs("\r", stdout);
	fflush(stdout);
}
802 static void
803 prompt(void)
805 printf("%s", PROMPT);
806 fflush(stdout);
809 int
810 main(int argc, char **argv)
812 int ch, sock, handshake;
813 struct event ev_sigint, ev_sigterm;
815 signal(SIGPIPE, SIG_IGN);
817 while ((ch = getopt(argc, argv, "C:cH:hK:P:v")) != -1) {
818 switch (ch) {
819 case 'C':
820 crtpath = optarg;
821 break;
822 case 'c':
823 tls = 1;
824 break;
825 case 'H':
826 host = optarg;
827 break;
828 case 'h':
829 usage(0);
830 break;
831 case 'K':
832 keypath = optarg;
833 break;
834 case 'P':
835 port = optarg;
836 break;
837 case 'v':
838 verbose = 1;
839 break;
840 default:
841 usage(1);
845 if (host == NULL)
846 host = "localhost";
847 if (port == NULL)
848 port = "1337";
850 argc -= optind;
851 argv += optind;
853 if (argc != 0)
854 usage(1);
855 /* if (!tls || (crtpath != NULL || keypath != NULL)) */
856 /* usage(1); */
857 if (!tls)
858 errx(1, "must enable tls (for now)");
860 log_init(1, LOG_DAEMON);
861 log_setverbose(verbose);
862 log_procinit(getprogname());
864 if ((tlsconf = tls_config_new()) == NULL)
865 fatalx("tls_config_new");
866 tls_config_insecure_noverifycert(tlsconf);
867 tls_config_insecure_noverifyname(tlsconf);
868 if (tls_config_set_keypair_file(tlsconf, crtpath, keypath) == -1)
869 fatalx("can't load certs (%s, %s)", crtpath, keypath);
871 if ((ctx = tls_client()) == NULL)
872 fatal("tls_client");
873 if (tls_configure(ctx, tlsconf) == -1)
874 fatalx("tls_configure: %s", tls_error(ctx));
876 log_info("connecting to %s:%s...", host, port);
878 if ((sock = openconn()) == -1)
879 fatalx("can't connect to %s:%s", host, port);
881 if (tls_connect_socket(ctx, sock, host) == -1)
882 fatalx("tls_connect_socket: %s", tls_error(ctx));
884 for (handshake = 0; !handshake;) {
885 switch (tls_handshake(ctx)) {
886 case -1:
887 fatalx("tls_handshake: %s", tls_error(ctx));
888 case 0:
889 handshake = 1;
890 break;
894 log_info("connected!");
896 mark_nonblock(sock);
898 event_init();
900 signal_set(&ev_sigint, SIGINT, sig_handler, NULL);
901 signal_set(&ev_sigterm, SIGINT, sig_handler, NULL);
903 signal_add(&ev_sigint, NULL);
904 signal_add(&ev_sigterm, NULL);
906 bev = bufferevent_new(sock, client_read, client_write, client_error,
907 NULL);
908 if (bev == NULL)
909 fatal("bufferevent_new");
911 /* setup tls/io */
912 event_set(&bev->ev_read, sock, EV_READ, tls_readcb, bev);
913 event_set(&bev->ev_write, sock, EV_WRITE, tls_writecb, bev);
915 bufferevent_enable(bev, EV_READ|EV_WRITE);
917 mark_nonblock(0);
918 inbev = bufferevent_new(0, repl_read, NULL, repl_error, NULL);
919 bufferevent_enable(inbev, EV_READ);
921 prompt();
922 event_dispatch();
924 bufferevent_free(bev);
925 tls_free(ctx);
926 tls_config_free(tlsconf);
927 close(sock);
929 return 0;