Blob


1 /*
2 * Copyright (c) 2021, 2022 Omar Polo <op@omarpolo.com>
3 *
4 * Permission to use, copy, modify, and distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
7 *
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15 */
17 #include "config.h"
19 #include <sys/types.h>
20 #include <sys/stat.h>
21 #include <sys/socket.h>
22 #include <sys/wait.h>
24 #include <ctype.h>
25 #include <errno.h>
26 #include <fcntl.h>
27 #include <limits.h>
28 #include <netdb.h>
29 #include <signal.h>
30 #include <stdio.h>
31 #include <stdlib.h>
32 #include <string.h>
33 #include <syslog.h>
34 #include <unistd.h>
36 #include "log.h"
38 #define MAXSOCK 32
39 #define BACKOFF 1
40 #define RETRIES 16
42 #ifndef __OpenBSD__
43 #define pledge(p, e) 0
44 #define unveil(p, m) 0
45 #endif
47 const char *addr; /* our addr */
48 const char *ssh_tflag;
49 const char *ssh_dest;
51 char ssh_host[256];
52 char ssh_port[16];
54 struct event sockev[MAXSOCK];
55 int socks[MAXSOCK];
56 int nsock;
58 int debug;
59 int verbose;
61 struct event sighupev;
62 struct event sigintev;
63 struct event sigtermev;
64 struct event sigchldev;
65 struct event siginfoev;
67 struct timeval timeout = {600, 0}; /* 10 minutes */
68 struct event timeoutev;
70 pid_t ssh_pid = -1;
72 int conn;
74 struct conn {
75 int ntentative;
76 struct timeval retry;
77 struct event waitev;
78 int source;
79 struct bufferevent *sourcebev;
80 int to;
81 struct bufferevent *tobev;
82 };
84 static void
85 sig_handler(int sig, short event, void *data)
86 {
87 int status;
89 switch (sig) {
90 case SIGHUP:
91 case SIGINT:
92 case SIGTERM:
93 log_info("quitting");
94 event_loopbreak();
95 break;
96 case SIGCHLD:
97 if (waitpid(ssh_pid, &status, WNOHANG) == -1)
98 fatal("waitpid");
99 ssh_pid = -1;
100 break;
101 #ifdef SIGINFO
102 case SIGINFO:
103 #else
104 case SIGUSR1:
105 #endif
106 log_info("connections: %d", conn);
110 static int
111 spawn_ssh(void)
113 log_debug("spawning ssh");
115 switch (ssh_pid = fork()) {
116 case -1:
117 log_warnx("fork");
118 return -1;
119 case 0:
120 execl(SSH_PROG, "ssh", "-L", ssh_tflag, "-NTq", ssh_dest,
121 NULL);
122 fatal("exec");
123 default:
124 return 0;
128 static void
129 conn_free(struct conn *c)
131 if (c->sourcebev != NULL)
132 bufferevent_free(c->sourcebev);
133 if (c->tobev != NULL)
134 bufferevent_free(c->tobev);
136 if (evtimer_pending(&c->waitev, NULL))
137 evtimer_del(&c->waitev);
139 close(c->source);
140 if (c->to != -1)
141 close(c->to);
143 free(c);
146 static void
147 killing_time(int fd, short event, void *data)
149 if (ssh_pid == -1)
150 return;
152 log_debug("timeout expired, killing ssh (%d)", ssh_pid);
153 kill(ssh_pid, SIGTERM);
154 ssh_pid = -1;
157 static void
158 nopcb(struct bufferevent *bev, void *d)
160 return;
163 static void
164 sreadcb(struct bufferevent *bev, void *d)
166 struct conn *c = d;
168 bufferevent_write_buffer(c->tobev, EVBUFFER_INPUT(bev));
171 static void
172 treadcb(struct bufferevent *bev, void *d)
174 struct conn *c = d;
176 bufferevent_write_buffer(c->sourcebev, EVBUFFER_INPUT(bev));
179 static void
180 errcb(struct bufferevent *bev, short event, void *d)
182 struct conn *c = d;
184 log_info("closing connection (event=%x)", event);
186 conn_free(c);
188 if (--conn == 0) {
189 log_debug("scheduling ssh termination (%llds)",
190 (long long)timeout.tv_sec);
191 if (timeout.tv_sec != 0) {
192 evtimer_set(&timeoutev, killing_time, NULL);
193 evtimer_add(&timeoutev, &timeout);
/*
 * Open a TCP connection to ssh_host:ssh_port, trying each address
 * getaddrinfo(3) returns in order.  Returns the connected socket, or
 * -1 after logging why the last attempt (or the lookup) failed.
 */
static int
connect_to_ssh(void)
{
	struct addrinfo	 hints, *ai, *list;
	const char	*why = NULL;
	int		 err, fd = -1, save;

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;

	if ((err = getaddrinfo(ssh_host, ssh_port, &hints, &list)) != 0) {
		log_warnx("getaddrinfo(\"%s\", \"%s\"): %s",
		    ssh_host, ssh_port, gai_strerror(err));
		return -1;
	}

	/* stop at the first address we manage to connect to */
	for (ai = list; ai != NULL && fd == -1; ai = ai->ai_next) {
		fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
		if (fd == -1) {
			why = "socket";
			continue;
		}

		if (connect(fd, ai->ai_addr, ai->ai_addrlen) == -1) {
			why = "connect";
			/* keep connect's errno for the log_warn below */
			save = errno;
			close(fd);
			errno = save;
			fd = -1;
		}
	}

	if (fd == -1)
		log_warn("%s", why);

	freeaddrinfo(list);
	return fd;
}
243 static void
244 try_to_connect(int fd, short event, void *d)
246 struct conn *c = d;
248 /* ssh may have died in the meantime */
249 if (ssh_pid == -1) {
250 conn_free(c);
251 return;
254 c->ntentative++;
255 log_info("trying to connect to %s:%s (%d/%d)", ssh_host, ssh_port,
256 c->ntentative, RETRIES);
258 if ((c->to = connect_to_ssh()) == -1) {
259 if (c->ntentative == RETRIES) {
260 log_warnx("giving up connecting");
261 conn_free(c);
262 return;
265 evtimer_set(&c->waitev, try_to_connect, c);
266 evtimer_add(&c->waitev, &c->retry);
267 return;
270 log_info("connected!");
272 c->sourcebev = bufferevent_new(c->source, sreadcb, nopcb, errcb, c);
273 c->tobev = bufferevent_new(c->to, treadcb, nopcb, errcb, c);
274 if (c->sourcebev == NULL || c->tobev == NULL) {
275 log_warn("bufferevent_new");
276 conn_free(c);
277 return;
280 bufferevent_enable(c->sourcebev, EV_READ|EV_WRITE);
281 bufferevent_enable(c->tobev, EV_READ|EV_WRITE);
284 static void
285 do_accept(int fd, short event, void *data)
287 struct conn *c;
288 int s;
290 log_debug("incoming connection");
292 if ((s = accept(fd, NULL, 0)) == -1) {
293 log_warn("accept");
294 return;
297 if (ssh_pid == -1 && spawn_ssh() == -1) {
298 close(s);
299 return;
302 if ((c = calloc(1, sizeof(*c))) == NULL) {
303 log_warn("calloc");
304 close(s);
305 return;
308 conn++;
309 if (evtimer_pending(&timeoutev, NULL))
310 evtimer_del(&timeoutev);
312 c->source = s;
313 c->to = -1;
314 c->retry.tv_sec = BACKOFF;
315 evtimer_set(&c->waitev, try_to_connect, c);
316 evtimer_add(&c->waitev, &c->retry);
/*
 * Copy the part of s preceding its first ':' into d, a buffer of len
 * bytes, and NUL-terminate it.  Returns a pointer to that ':' inside s,
 * or NULL (leaving d untouched) when s contains no ':' or the prefix
 * plus its terminating NUL does not fit in len bytes.
 */
static const char *
copysec(const char *s, char *d, size_t len)
{
	const char	*c;
	size_t		 n;

	if ((c = strchr(s, ':')) == NULL)
		return NULL;

	/*
	 * n bytes plus the NUL need n < len.  The old "n >= len-1" test
	 * rejected prefixes that fit exactly and wrapped when len == 0.
	 */
	n = (size_t)(c - s);
	if (len == 0 || n >= len)
		return NULL;

	memset(d, 0, len);
	memcpy(d, s, n);
	return c;
}
/* close(2) fd without clobbering errno, so the caller can still report it. */
static void
close_keep_errno(int fd)
{
	int saved_errno = errno;

	close(fd);
	errno = saved_errno;
}

/*
 * Create the listening sockets for addr ("port" or "host:port"): one
 * per address getaddrinfo(3) returns, up to MAXSOCK, stored in socks[]
 * and counted in nsock.  Addresses that fail to bind or listen are
 * skipped; if none succeeds the program dies with the last error.
 */
static void
bind_socket(void)
{
	struct addrinfo	 hints, *res, *res0;
	int		 v, r, s;
	char		 host[64];
	const char	*c, *h, *port, *cause = NULL;

	if ((c = strchr(addr, ':')) == NULL) {
		h = NULL;
		port = addr;
	} else {
		if ((c = copysec(addr, host, sizeof(host))) == NULL)
			fatalx("name too long: %s", addr);

		h = host;
		port = c + 1;
	}

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_flags = AI_PASSIVE;

	r = getaddrinfo(h, port, &hints, &res0);
	if (r != 0)
		fatalx("getaddrinfo(%s): %s", addr, gai_strerror(r));

	for (res = res0; res != NULL && nsock < MAXSOCK; res = res->ai_next) {
		s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
		if (s == -1) {
			cause = "socket";
			continue;
		}

		/*
		 * Address-reuse options only influence bind(2) if they
		 * are set before it; setting them afterwards is a no-op.
		 */
		v = 1;
		if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &v,
		    sizeof(v)) == -1)
			fatal("setsockopt(SO_REUSEADDR)");

		v = 1;
		if (setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &v,
		    sizeof(v)) == -1)
			fatal("setsockopt(SO_REUSEPORT)");

		if (bind(s, res->ai_addr, res->ai_addrlen) == -1) {
			cause = "bind";
			close_keep_errno(s);
			continue;
		}

		if (listen(s, 5) == -1) {
			cause = "listen";
			close_keep_errno(s);
			continue;
		}

		socks[nsock++] = s;
	}

	if (nsock == 0)
		fatal("%s", cause != NULL ? cause : "bind_socket");

	freeaddrinfo(res0);
}
397 static void
398 parse_sshaddr(void)
400 const char *c;
402 if (isdigit((unsigned char)*ssh_tflag)) {
403 strlcpy(ssh_host, "localhost", sizeof(ssh_host));
404 if (copysec(ssh_tflag, ssh_port, sizeof(ssh_port)) == NULL)
405 goto err;
406 return;
409 if ((c = copysec(ssh_tflag, ssh_host, sizeof(ssh_host))) == NULL)
410 goto err;
411 if (copysec(c+1, ssh_port, sizeof(ssh_port)) == NULL)
412 goto err;
413 return;
415 err:
416 fatalx("wrong value for -B");
419 static void __dead
420 usage(void)
422 fprintf(stderr, "usage: %s [-dv] -B sshaddr -b addr [-t timeout]"
423 " destination\n", getprogname());
424 exit(1);
427 int
428 main(int argc, char **argv)
430 int ch, i, fd;
431 const char *errstr;
432 struct stat sb;
434 /*
435 * Ensure we have fds 0-2 open so that we have no issue with
436 * calling bind_socket before daemon(3).
437 */
438 for (i = 0; i < 3; ++i) {
439 if (fstat(i, &sb) == -1) {
440 if ((fd = open("/dev/null", O_RDWR)) != -1) {
441 if (dup2(fd, i) == -1)
442 exit(1);
443 if (fd > i)
444 close(fd);
445 } else
446 exit(1);
450 log_init(1, LOG_DAEMON);
451 log_setverbose(1);
453 while ((ch = getopt(argc, argv, "B:b:dt:v")) != -1) {
454 switch (ch) {
455 case 'B':
456 ssh_tflag = optarg;
457 parse_sshaddr();
458 break;
459 case 'b':
460 addr = optarg;
461 break;
462 case 'd':
463 debug = 1;
464 break;
465 case 't':
466 timeout.tv_sec = strtonum(optarg, 0, INT_MAX, &errstr);
467 if (errstr != NULL)
468 fatalx("timeout is %s: %s", errstr, optarg);
469 break;
470 case 'v':
471 verbose = 1;
472 break;
473 default:
474 usage();
477 argc -= optind;
478 argv += optind;
480 if (argc != 1 || addr == NULL || ssh_tflag == NULL)
481 usage();
483 ssh_dest = argv[0];
485 bind_socket();
487 log_init(debug, LOG_DAEMON);
488 log_setverbose(verbose);
490 if (!debug)
491 daemon(1, 0);
493 signal(SIGPIPE, SIG_IGN);
495 event_init();
497 /* initialize the timer */
498 evtimer_set(&timeoutev, killing_time, NULL);
500 signal_set(&sighupev, SIGHUP, sig_handler, NULL);
501 signal_set(&sigintev, SIGINT, sig_handler, NULL);
502 signal_set(&sigtermev, SIGTERM, sig_handler, NULL);
503 signal_set(&sigchldev, SIGCHLD, sig_handler, NULL);
504 #ifdef SIGINFO
505 signal_set(&siginfoev, SIGINFO, sig_handler, NULL);
506 #else
507 signal_set(&siginfoev, SIGUSR1, sig_handler, NULL);
508 #endif
510 signal_add(&sighupev, NULL);
511 signal_add(&sigintev, NULL);
512 signal_add(&sigtermev, NULL);
513 signal_add(&sigchldev, NULL);
514 signal_add(&siginfoev, NULL);
516 for (i = 0; i < nsock; ++i) {
517 event_set(&sockev[i], socks[i], EV_READ|EV_PERSIST,
518 do_accept, NULL);
519 event_add(&sockev[i], NULL);
522 if (unveil(SSH_PROG, "x") == -1)
523 fatal("unveil(%s)", SSH_PROG);
525 /*
526 * dns, inet: bind the socket and connect to the childs.
527 * proc, exec: execute ssh on demand.
528 */
529 if (pledge("stdio dns inet proc exec", NULL) == -1)
530 fatal("pledge");
532 log_info("starting");
533 event_dispatch();
535 if (ssh_pid != -1)
536 kill(ssh_pid, SIGINT);
538 return 0;