Blob


/*
 * Copyright (c) 2021, 2022 Omar Polo <op@omarpolo.com>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */
17 #include "config.h"
19 #include <sys/types.h>
20 #include <sys/stat.h>
21 #include <sys/socket.h>
22 #include <sys/wait.h>
24 #include <ctype.h>
25 #include <errno.h>
26 #include <fcntl.h>
27 #include <limits.h>
28 #include <netdb.h>
29 #include <signal.h>
30 #include <stdio.h>
31 #include <stdlib.h>
32 #include <string.h>
33 #include <syslog.h>
34 #include <unistd.h>
36 #include "log.h"
/* ssh binary executed on demand; can be overridden at build time. */
#ifndef SSH_PATH
#define SSH_PATH	"/usr/bin/ssh"
#endif

enum {
	MAXSOCK = 32,	/* upper bound on listening sockets (socks[]) */
	BACKOFF = 1,	/* seconds to wait before each attempt to reach ssh */
	RETRIES = 16,	/* attempts to reach ssh before dropping a client */
};

/* pledge(2)/unveil(2) exist only on OpenBSD: turn them into no-ops. */
#ifndef __OpenBSD__
#define pledge(p, e)	0
#define unveil(p, m)	0
#endif
51 const char *addr; /* our addr */
52 const char *ssh_tflag;
53 const char *ssh_dest;
55 char ssh_host[256];
56 char ssh_port[16];
58 struct event sockev[MAXSOCK];
59 int socks[MAXSOCK];
60 int nsock;
62 int debug;
63 int verbose;
65 struct event sighupev;
66 struct event sigintev;
67 struct event sigtermev;
68 struct event sigchldev;
69 struct event siginfoev;
71 struct timeval timeout = {600, 0}; /* 10 minutes */
72 struct event timeoutev;
74 pid_t ssh_pid = -1;
76 int conn;
78 struct conn {
79 int ntentative;
80 struct timeval retry;
81 struct event waitev;
82 int source;
83 struct bufferevent *sourcebev;
84 int to;
85 struct bufferevent *tobev;
86 };
88 static void
89 sig_handler(int sig, short event, void *data)
90 {
91 int status;
93 switch (sig) {
94 case SIGHUP:
95 case SIGINT:
96 case SIGTERM:
97 log_info("quitting");
98 event_loopbreak();
99 break;
100 case SIGCHLD:
101 if (waitpid(ssh_pid, &status, WNOHANG) == -1)
102 fatal("waitpid");
103 ssh_pid = -1;
104 break;
105 #ifdef SIGINFO
106 case SIGINFO:
107 #else
108 case SIGUSR1:
109 #endif
110 log_info("connections: %d", conn);
114 static int
115 spawn_ssh(void)
117 log_debug("spawning ssh");
119 switch (ssh_pid = fork()) {
120 case -1:
121 log_warnx("fork");
122 return -1;
123 case 0:
124 execl(SSH_PATH, "ssh", "-L", ssh_tflag, "-NTq", ssh_dest,
125 NULL);
126 fatal("exec");
127 default:
128 return 0;
132 static void
133 conn_free(struct conn *c)
135 if (c->sourcebev != NULL)
136 bufferevent_free(c->sourcebev);
137 if (c->tobev != NULL)
138 bufferevent_free(c->tobev);
140 if (evtimer_pending(&c->waitev, NULL))
141 evtimer_del(&c->waitev);
143 close(c->source);
144 if (c->to != -1)
145 close(c->to);
147 free(c);
150 static void
151 killing_time(int fd, short event, void *data)
153 if (ssh_pid == -1)
154 return;
156 log_debug("timeout expired, killing ssh (%d)", ssh_pid);
157 kill(ssh_pid, SIGTERM);
158 ssh_pid = -1;
/* Bufferevent write callback that intentionally does nothing. */
static void
nopcb(struct bufferevent *bev, void *d)
{
	(void)bev;
	(void)d;
}
167 static void
168 sreadcb(struct bufferevent *bev, void *d)
170 struct conn *c = d;
172 bufferevent_write_buffer(c->tobev, EVBUFFER_INPUT(bev));
175 static void
176 treadcb(struct bufferevent *bev, void *d)
178 struct conn *c = d;
180 bufferevent_write_buffer(c->sourcebev, EVBUFFER_INPUT(bev));
183 static void
184 errcb(struct bufferevent *bev, short event, void *d)
186 struct conn *c = d;
188 log_info("closing connection (event=%x)", event);
190 conn_free(c);
192 if (--conn == 0) {
193 log_debug("scheduling ssh termination (%llds)",
194 (long long)timeout.tv_sec);
195 if (timeout.tv_sec != 0) {
196 evtimer_set(&timeoutev, killing_time, NULL);
197 evtimer_add(&timeoutev, &timeout);
/*
 * Open a TCP connection to the local end of the ssh forwarding
 * (ssh_host:ssh_port), trying each address getaddrinfo(3) yields in
 * turn.
 *
 * Returns the connected socket, or -1 after logging the reason.
 */
static int
connect_to_ssh(void)
{
	struct addrinfo	 hints, *ai, *list;
	const char	*why = NULL;
	int		 fd = -1, gai, serr;

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;

	if ((gai = getaddrinfo(ssh_host, ssh_port, &hints, &list)) != 0) {
		log_warnx("getaddrinfo(\"%s\", \"%s\"): %s",
		    ssh_host, ssh_port, gai_strerror(gai));
		return -1;
	}

	/* Stop at the first address we manage to connect to. */
	for (ai = list; ai != NULL && fd == -1; ai = ai->ai_next) {
		fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
		if (fd == -1) {
			why = "socket";
		} else if (connect(fd, ai->ai_addr, ai->ai_addrlen) == -1) {
			why = "connect";
			/* keep connect's errno for log_warn below */
			serr = errno;
			close(fd);
			errno = serr;
			fd = -1;
		}
	}

	if (fd == -1)
		log_warn("%s", why);

	freeaddrinfo(list);
	return fd;
}
247 static void
248 try_to_connect(int fd, short event, void *d)
250 struct conn *c = d;
252 /* ssh may have died in the meantime */
253 if (ssh_pid == -1) {
254 conn_free(c);
255 return;
258 c->ntentative++;
259 log_info("trying to connect to %s:%s (%d/%d)", ssh_host, ssh_port,
260 c->ntentative, RETRIES);
262 if ((c->to = connect_to_ssh()) == -1) {
263 if (c->ntentative == RETRIES) {
264 log_warnx("giving up connecting");
265 conn_free(c);
266 return;
269 evtimer_set(&c->waitev, try_to_connect, c);
270 evtimer_add(&c->waitev, &c->retry);
271 return;
274 log_info("connected!");
276 c->sourcebev = bufferevent_new(c->source, sreadcb, nopcb, errcb, c);
277 c->tobev = bufferevent_new(c->to, treadcb, nopcb, errcb, c);
278 if (c->sourcebev == NULL || c->tobev == NULL) {
279 log_warn("bufferevent_new");
280 conn_free(c);
281 return;
284 bufferevent_enable(c->sourcebev, EV_READ|EV_WRITE);
285 bufferevent_enable(c->tobev, EV_READ|EV_WRITE);
288 static void
289 do_accept(int fd, short event, void *data)
291 struct conn *c;
292 int s;
294 log_debug("incoming connection");
296 if ((s = accept(fd, NULL, 0)) == -1) {
297 log_warn("accept");
298 return;
301 if (ssh_pid == -1 && spawn_ssh() == -1) {
302 close(s);
303 return;
306 if ((c = calloc(1, sizeof(*c))) == NULL) {
307 log_warn("calloc");
308 close(s);
309 return;
312 conn++;
313 if (evtimer_pending(&timeoutev, NULL))
314 evtimer_del(&timeoutev);
316 c->source = s;
317 c->to = -1;
318 c->retry.tv_sec = BACKOFF;
319 evtimer_set(&c->waitev, try_to_connect, c);
320 evtimer_add(&c->waitev, &c->retry);
/*
 * Copy the part of s preceding its first ':' into d, a buffer of len
 * bytes, NUL-terminating it (the rest of d is zeroed).
 *
 * Returns a pointer to that ':' inside s, or NULL when s has no ':'
 * or when the prefix plus its terminating NUL does not fit in d.
 */
static const char *
copysec(const char *s, char *d, size_t len)
{
	const char *c;
	size_t n;

	if ((c = strchr(s, ':')) == NULL)
		return NULL;

	n = (size_t)(c - s);
	/* n bytes + NUL must fit: n <= len - 1.  Also rejects len == 0. */
	if (n >= len)
		return NULL;

	memset(d, 0, len);
	memcpy(d, s, n);
	return c;
}
/*
 * Set up the listening sockets for the -b address, given as "port"
 * (all local addresses) or "host:port".
 *
 * One socket is created per address getaddrinfo(3) returns, up to
 * MAXSOCK, and stored in socks[]/nsock.  Exits via fatal/fatalx on a
 * malformed address or when no socket at all could be set up.
 */
static void
bind_socket(void)
{
	struct addrinfo hints, *res, *res0;
	int v, r, saved_errno;
	char host[256];		/* same capacity as ssh_host */
	const char *c, *h, *port, *cause = NULL;

	if ((c = strchr(addr, ':')) == NULL) {
		h = NULL;
		port = addr;
	} else {
		if ((c = copysec(addr, host, sizeof(host))) == NULL)
			fatalx("name too long: %s", addr);
		h = host;
		port = c+1;
	}

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_flags = AI_PASSIVE;

	r = getaddrinfo(h, port, &hints, &res0);
	if (r != 0)
		fatalx("getaddrinfo(%s): %s", addr, gai_strerror(r));

	for (res = res0; res && nsock < MAXSOCK; res = res->ai_next) {
		socks[nsock] = socket(res->ai_family, res->ai_socktype,
		    res->ai_protocol);
		if (socks[nsock] == -1) {
			cause = "socket";
			continue;
		}

		/*
		 * SO_REUSEADDR is consulted by bind(2), so it has to be
		 * set before binding for a restart to succeed while old
		 * connections linger in TIME_WAIT.  SO_REUSEPORT is not
		 * set.  On some systems (e.g. Linux) it would let another
		 * socket bind the very same address and port.
		 */
		v = 1;
		if (setsockopt(socks[nsock], SOL_SOCKET, SO_REUSEADDR, &v,
		    sizeof(v)) == -1)
			fatal("setsockopt(SO_REUSEADDR)");

		if (bind(socks[nsock], res->ai_addr, res->ai_addrlen) == -1) {
			cause = "bind";
			saved_errno = errno;
			close(socks[nsock]);
			errno = saved_errno;
			continue;
		}

		if (listen(socks[nsock], 5) == -1) {
			cause = "listen";
			saved_errno = errno;
			close(socks[nsock]);
			errno = saved_errno;
			continue;
		}

		nsock++;
	}
	if (nsock == 0)
		fatal("%s", cause);

	freeaddrinfo(res0);
}
401 static void
402 parse_sshaddr(void)
404 const char *c;
406 if (isdigit((unsigned char)*ssh_tflag)) {
407 strlcpy(ssh_host, "localhost", sizeof(ssh_host));
408 if (copysec(ssh_tflag, ssh_port, sizeof(ssh_port)) == NULL)
409 goto err;
410 return;
413 if ((c = copysec(ssh_tflag, ssh_host, sizeof(ssh_host))) == NULL)
414 goto err;
415 if (copysec(c+1, ssh_port, sizeof(ssh_port)) == NULL)
416 goto err;
417 return;
419 err:
420 fatalx("wrong value for -B");
423 static void __dead
424 usage(void)
426 fprintf(stderr, "usage: %s [-dv] -B sshaddr -b addr [-t timeout]"
427 " destination\n", getprogname());
428 exit(1);
431 int
432 main(int argc, char **argv)
434 int ch, i, fd;
435 const char *errstr;
436 struct stat sb;
438 /*
439 * Ensure we have fds 0-2 open so that we have no issue with
440 * calling bind_socket before daemon(3).
441 */
442 for (i = 0; i < 3; ++i) {
443 if (fstat(i, &sb) == -1) {
444 if ((fd = open("/dev/null", O_RDWR)) != -1) {
445 if (dup2(fd, i) == -1)
446 exit(1);
447 if (fd > i)
448 close(fd);
449 } else
450 exit(1);
454 log_init(1, LOG_DAEMON);
455 log_setverbose(1);
457 while ((ch = getopt(argc, argv, "B:b:dt:v")) != -1) {
458 switch (ch) {
459 case 'B':
460 ssh_tflag = optarg;
461 parse_sshaddr();
462 break;
463 case 'b':
464 addr = optarg;
465 break;
466 case 'd':
467 debug = 1;
468 break;
469 case 't':
470 timeout.tv_sec = strtonum(optarg, 0, INT_MAX, &errstr);
471 if (errstr != NULL)
472 fatalx("timeout is %s: %s", errstr, optarg);
473 break;
474 case 'v':
475 verbose = 1;
476 break;
477 default:
478 usage();
481 argc -= optind;
482 argv += optind;
484 if (argc != 1 || addr == NULL || ssh_tflag == NULL)
485 usage();
487 ssh_dest = argv[0];
489 bind_socket();
491 log_init(debug, LOG_DAEMON);
492 log_setverbose(verbose);
494 if (!debug)
495 daemon(1, 0);
497 signal(SIGPIPE, SIG_IGN);
499 event_init();
501 /* initialize the timer */
502 evtimer_set(&timeoutev, killing_time, NULL);
504 signal_set(&sighupev, SIGHUP, sig_handler, NULL);
505 signal_set(&sigintev, SIGINT, sig_handler, NULL);
506 signal_set(&sigtermev, SIGTERM, sig_handler, NULL);
507 signal_set(&sigchldev, SIGCHLD, sig_handler, NULL);
508 #ifdef SIGINFO
509 signal_set(&siginfoev, SIGINFO, sig_handler, NULL);
510 #else
511 signal_set(&siginfoev, SIGUSR1, sig_handler, NULL);
512 #endif
514 signal_add(&sighupev, NULL);
515 signal_add(&sigintev, NULL);
516 signal_add(&sigtermev, NULL);
517 signal_add(&sigchldev, NULL);
518 signal_add(&siginfoev, NULL);
520 for (i = 0; i < nsock; ++i) {
521 event_set(&sockev[i], socks[i], EV_READ|EV_PERSIST,
522 do_accept, NULL);
523 event_add(&sockev[i], NULL);
526 if (unveil(SSH_PATH, "x") == -1)
527 fatal("unveil(%s)", SSH_PATH);
529 /*
530 * dns, inet: bind the socket and connect to the childs.
531 * proc, exec: execute ssh on demand.
532 */
533 if (pledge("stdio dns inet proc exec", NULL) == -1)
534 fatal("pledge");
536 log_info("starting");
537 event_dispatch();
539 if (ssh_pid != -1)
540 kill(ssh_pid, SIGINT);
542 return 0;