Blob


1 /*
2 * Copyright (c) 2021, 2022 Omar Polo <op@omarpolo.com>
3 *
4 * Permission to use, copy, modify, and distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
7 *
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15 */
17 #include <sys/types.h>
18 #include <sys/stat.h>
19 #include <sys/socket.h>
20 #include <sys/wait.h>
22 #include <ctype.h>
23 #include <errno.h>
24 #include <event.h>
25 #include <fcntl.h>
26 #include <limits.h>
27 #include <netdb.h>
28 #include <signal.h>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32 #include <syslog.h>
33 #include <unistd.h>
35 #include "log.h"
/* ssh binary to execute; can be overridden at build time with -DSSH_PATH. */
#ifndef SSH_PATH
#define SSH_PATH	"/usr/bin/ssh"
#endif

#define MAXSOCK		32	/* max listening sockets created by bind_socket() */
#define BACKOFF		1	/* seconds to wait before each connection attempt */
#define RETRIES		16	/* connection attempts before giving up */

/* pledge(2) and unveil(2) are OpenBSD-only: elsewhere they succeed as no-ops. */
#ifndef __OpenBSD__
#define pledge(p, e)	0
#define unveil(p, m)	0
#endif
50 const char *addr; /* our addr */
51 const char *ssh_tflag;
52 const char *ssh_dest;
54 char ssh_host[256];
55 char ssh_port[16];
57 struct event sockev[MAXSOCK];
58 int socks[MAXSOCK];
59 int nsock;
61 int debug;
62 int verbose;
64 struct event sighupev;
65 struct event sigintev;
66 struct event sigtermev;
67 struct event sigchldev;
68 struct event siginfoev;
70 struct timeval timeout = {600, 0}; /* 10 minutes */
71 struct event timeoutev;
73 pid_t ssh_pid = -1;
75 int conn;
77 struct conn {
78 int ntentative;
79 struct timeval retry;
80 struct event waitev;
81 int source;
82 struct bufferevent *sourcebev;
83 int to;
84 struct bufferevent *tobev;
85 };
87 static void
88 sig_handler(int sig, short event, void *data)
89 {
90 int status;
92 switch (sig) {
93 case SIGHUP:
94 case SIGINT:
95 case SIGTERM:
96 log_info("quitting");
97 event_loopbreak();
98 break;
99 case SIGCHLD:
100 if (waitpid(ssh_pid, &status, WNOHANG) == -1)
101 fatal("waitpid");
102 ssh_pid = -1;
103 break;
104 #ifdef SIGINFO
105 case SIGINFO:
106 #else
107 case SIGUSR1:
108 #endif
109 log_info("connections: %d", conn);
113 static int
114 spawn_ssh(void)
116 log_debug("spawning ssh");
118 switch (ssh_pid = fork()) {
119 case -1:
120 log_warnx("fork");
121 return -1;
122 case 0:
123 execl(SSH_PATH, "ssh", "-L", ssh_tflag, "-NTq", ssh_dest,
124 NULL);
125 fatal("exec");
126 default:
127 return 0;
131 static void
132 conn_free(struct conn *c)
134 if (c->sourcebev != NULL)
135 bufferevent_free(c->sourcebev);
136 if (c->tobev != NULL)
137 bufferevent_free(c->tobev);
139 if (evtimer_pending(&c->waitev, NULL))
140 evtimer_del(&c->waitev);
142 close(c->source);
143 if (c->to != -1)
144 close(c->to);
146 free(c);
149 static void
150 killing_time(int fd, short event, void *data)
152 if (ssh_pid == -1)
153 return;
155 log_debug("timeout expired, killing ssh (%d)", ssh_pid);
156 kill(ssh_pid, SIGTERM);
157 ssh_pid = -1;
/* Write callback for both bufferevents: nothing to do. */
static void
nopcb(struct bufferevent *bev, void *d)
{
	(void)bev;
	(void)d;
}
166 static void
167 sreadcb(struct bufferevent *bev, void *d)
169 struct conn *c = d;
171 bufferevent_write_buffer(c->tobev, EVBUFFER_INPUT(bev));
174 static void
175 treadcb(struct bufferevent *bev, void *d)
177 struct conn *c = d;
179 bufferevent_write_buffer(c->sourcebev, EVBUFFER_INPUT(bev));
182 static void
183 errcb(struct bufferevent *bev, short event, void *d)
185 struct conn *c = d;
187 log_info("closing connection (event=%x)", event);
189 conn_free(c);
191 if (--conn == 0) {
192 log_debug("scheduling ssh termination (%llds)",
193 (long long)timeout.tv_sec);
194 if (timeout.tv_sec != 0) {
195 evtimer_set(&timeoutev, killing_time, NULL);
196 evtimer_add(&timeoutev, &timeout);
/*
 * Open a TCP connection to ssh_host:ssh_port, the local end of the
 * ssh forward, trying each address getaddrinfo(3) returns in turn.
 *
 * Note that connect(2) is done in blocking mode, so the event loop is
 * stalled while it is in progress.
 *
 * Returns the connected socket, or -1 (with a logged reason) on failure.
 */
static int
connect_to_ssh(void)
{
	struct addrinfo	 hints, *res, *res0;
	int		 r, saved_errno, sock = -1;
	/* initialized so an empty result list cannot read garbage */
	const char	*cause = "no address to connect to";

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;

	r = getaddrinfo(ssh_host, ssh_port, &hints, &res0);
	if (r != 0) {
		log_warnx("getaddrinfo(\"%s\", \"%s\"): %s",
		    ssh_host, ssh_port, gai_strerror(r));
		return -1;
	}

	for (res = res0; res != NULL; res = res->ai_next) {
		sock = socket(res->ai_family, res->ai_socktype,
		    res->ai_protocol);
		if (sock == -1) {
			cause = "socket";
			continue;
		}

		if (connect(sock, res->ai_addr, res->ai_addrlen) == -1) {
			cause = "connect";
			/* close(2) must not clobber the errno we log */
			saved_errno = errno;
			close(sock);
			errno = saved_errno;
			sock = -1;
			continue;
		}

		break;
	}

	if (sock == -1)
		log_warn("%s", cause);

	freeaddrinfo(res0);
	return sock;
}
246 static void
247 try_to_connect(int fd, short event, void *d)
249 struct conn *c = d;
251 /* ssh may have died in the meantime */
252 if (ssh_pid == -1) {
253 conn_free(c);
254 return;
257 c->ntentative++;
258 log_info("trying to connect to %s:%s (%d/%d)", ssh_host, ssh_port,
259 c->ntentative, RETRIES);
261 if ((c->to = connect_to_ssh()) == -1) {
262 if (c->ntentative == RETRIES) {
263 log_warnx("giving up connecting");
264 conn_free(c);
265 return;
268 evtimer_set(&c->waitev, try_to_connect, c);
269 evtimer_add(&c->waitev, &c->retry);
270 return;
273 log_info("connected!");
275 c->sourcebev = bufferevent_new(c->source, sreadcb, nopcb, errcb, c);
276 c->tobev = bufferevent_new(c->to, treadcb, nopcb, errcb, c);
277 if (c->sourcebev == NULL || c->tobev == NULL) {
278 log_warn("bufferevent_new");
279 conn_free(c);
280 return;
283 bufferevent_enable(c->sourcebev, EV_READ|EV_WRITE);
284 bufferevent_enable(c->tobev, EV_READ|EV_WRITE);
287 static void
288 do_accept(int fd, short event, void *data)
290 struct conn *c;
291 int s;
293 log_debug("incoming connection");
295 if ((s = accept(fd, NULL, 0)) == -1) {
296 log_warn("accept");
297 return;
300 if (ssh_pid == -1 && spawn_ssh() == -1) {
301 close(s);
302 return;
305 if ((c = calloc(1, sizeof(*c))) == NULL) {
306 log_warn("calloc");
307 close(s);
308 return;
311 conn++;
312 if (evtimer_pending(&timeoutev, NULL))
313 evtimer_del(&timeoutev);
315 c->source = s;
316 c->to = -1;
317 c->retry.tv_sec = BACKOFF;
318 evtimer_set(&c->waitev, try_to_connect, c);
319 evtimer_add(&c->waitev, &c->retry);
/*
 * Copy the part of s that precedes its first ':' into d, a buffer of
 * len bytes, zero-filling d so the result is NUL-terminated.
 *
 * Returns a pointer to that ':' inside s, or NULL when s has no ':' or
 * the prefix plus its terminating NUL does not fit in d (this includes
 * len == 0).  On failure d is left untouched.
 */
static const char *
copysec(const char *s, char *d, size_t len)
{
	const char	*c;
	size_t		 n;

	if ((c = strchr(s, ':')) == NULL)
		return NULL;

	n = (size_t)(c - s);
	/*
	 * Need room for n bytes plus the NUL, i.e. n <= len - 1.  Written
	 * as n >= len so that len == 0 cannot underflow.
	 */
	if (n >= len)
		return NULL;

	memset(d, 0, len);
	memcpy(d, s, n);
	return c;
}
/*
 * Resolve addr ("[host:]port", from -b) and create a listening TCP
 * socket for each resulting address, up to MAXSOCK.  The sockets are
 * stored in socks[] and counted in nsock.  Exits via fatal()/fatalx()
 * when the address is malformed or no socket could be set up.
 */
static void
bind_socket(void)
{
	struct addrinfo	 hints, *res, *res0;
	int		 v, r, s, saved_errno;
	char		 host[64];
	const char	*c, *h, *port;
	/* initialized so an empty result list cannot read garbage */
	const char	*cause = "no usable address";

	if ((c = strchr(addr, ':')) == NULL) {
		h = NULL;
		port = addr;
	} else {
		if ((c = copysec(addr, host, sizeof(host))) == NULL)
			fatalx("name too long: %s", addr);

		h = host;
		port = c+1;
	}

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_flags = AI_PASSIVE;

	r = getaddrinfo(h, port, &hints, &res0);
	if (r != 0)
		fatalx("getaddrinfo(%s): %s", addr, gai_strerror(r));

	for (res = res0; res != NULL && nsock < MAXSOCK; res = res->ai_next) {
		s = socket(res->ai_family, res->ai_socktype,
		    res->ai_protocol);
		if (s == -1) {
			cause = "socket";
			continue;
		}

		/*
		 * Address-reuse options only affect bind(2) when they are
		 * set before it; setting them afterwards has no effect.
		 */
		v = 1;
		if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &v,
		    sizeof(v)) == -1)
			fatal("setsockopt(SO_REUSEADDR)");
#ifdef SO_REUSEPORT
		v = 1;
		if (setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &v,
		    sizeof(v)) == -1)
			fatal("setsockopt(SO_REUSEPORT)");
#endif

		if (bind(s, res->ai_addr, res->ai_addrlen) == -1) {
			cause = "bind";
			saved_errno = errno;
			close(s);
			errno = saved_errno;
			continue;
		}

		if (listen(s, 5) == -1) {
			cause = "listen";
			saved_errno = errno;
			close(s);
			errno = saved_errno;
			continue;
		}

		socks[nsock++] = s;
	}

	if (nsock == 0)
		fatal("%s", cause);

	freeaddrinfo(res0);
}
400 static void
401 parse_sshaddr(void)
403 const char *c;
405 if (isdigit((unsigned char)*ssh_tflag)) {
406 strlcpy(ssh_host, "localhost", sizeof(ssh_host));
407 if (copysec(ssh_tflag, ssh_port, sizeof(ssh_port)) == NULL)
408 goto err;
409 return;
412 if ((c = copysec(ssh_tflag, ssh_host, sizeof(ssh_host))) == NULL)
413 goto err;
414 if (copysec(c+1, ssh_port, sizeof(ssh_port)) == NULL)
415 goto err;
416 return;
418 err:
419 fatalx("wrong value for -B");
422 static void __dead
423 usage(void)
425 fprintf(stderr, "usage: %s [-dv] -B sshaddr -b addr [-t timeout]"
426 " destination\n", getprogname());
427 exit(1);
430 int
431 main(int argc, char **argv)
433 int ch, i, fd;
434 const char *errstr;
435 struct stat sb;
437 /*
438 * Ensure we have fds 0-2 open so that we have no issue with
439 * calling bind_socket before daemon(3).
440 */
441 for (i = 0; i < 3; ++i) {
442 if (fstat(i, &sb) == -1) {
443 if ((fd = open("/dev/null", O_RDWR)) != -1) {
444 if (dup2(fd, i) == -1)
445 exit(1);
446 if (fd > i)
447 close(fd);
448 } else
449 exit(1);
453 log_init(1, LOG_DAEMON);
454 log_setverbose(1);
456 while ((ch = getopt(argc, argv, "B:b:dt:v")) != -1) {
457 switch (ch) {
458 case 'B':
459 ssh_tflag = optarg;
460 parse_sshaddr();
461 break;
462 case 'b':
463 addr = optarg;
464 break;
465 case 'd':
466 debug = 1;
467 break;
468 case 't':
469 timeout.tv_sec = strtonum(optarg, 0, INT_MAX, &errstr);
470 if (errstr != NULL)
471 fatalx("timeout is %s: %s", errstr, optarg);
472 break;
473 case 'v':
474 verbose = 1;
475 break;
476 default:
477 usage();
480 argc -= optind;
481 argv += optind;
483 if (argc != 1 || addr == NULL || ssh_tflag == NULL)
484 usage();
486 ssh_dest = argv[0];
488 bind_socket();
490 log_init(debug, LOG_DAEMON);
491 log_setverbose(verbose);
493 if (!debug)
494 daemon(1, 0);
496 signal(SIGPIPE, SIG_IGN);
498 event_init();
500 /* initialize the timer */
501 evtimer_set(&timeoutev, killing_time, NULL);
503 signal_set(&sighupev, SIGHUP, sig_handler, NULL);
504 signal_set(&sigintev, SIGINT, sig_handler, NULL);
505 signal_set(&sigtermev, SIGTERM, sig_handler, NULL);
506 signal_set(&sigchldev, SIGCHLD, sig_handler, NULL);
507 #ifdef SIGINFO
508 signal_set(&siginfoev, SIGINFO, sig_handler, NULL);
509 #else
510 signal_set(&siginfoev, SIGUSR1, sig_handler, NULL);
511 #endif
513 signal_add(&sighupev, NULL);
514 signal_add(&sigintev, NULL);
515 signal_add(&sigtermev, NULL);
516 signal_add(&sigchldev, NULL);
517 signal_add(&siginfoev, NULL);
519 for (i = 0; i < nsock; ++i) {
520 event_set(&sockev[i], socks[i], EV_READ|EV_PERSIST,
521 do_accept, NULL);
522 event_add(&sockev[i], NULL);
525 if (unveil(SSH_PATH, "x") == -1)
526 fatal("unveil(%s)", SSH_PATH);
528 /*
529 * dns, inet: bind the socket and connect to the childs.
530 * proc, exec: execute ssh on demand.
531 */
532 if (pledge("stdio dns inet proc exec", NULL) == -1)
533 fatal("pledge");
535 log_info("starting");
536 event_dispatch();
538 if (ssh_pid != -1)
539 kill(ssh_pid, SIGINT);
541 return 0;