Blob


/*
 * Copyright (c) 2021, 2022 Omar Polo <op@omarpolo.com>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */
17 #include <sys/types.h>
18 #include <sys/stat.h>
19 #include <sys/socket.h>
20 #include <sys/wait.h>
22 #include <ctype.h>
23 #include <errno.h>
24 #include <event.h>
25 #include <fcntl.h>
26 #include <limits.h>
27 #include <netdb.h>
28 #include <signal.h>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32 #include <syslog.h>
33 #include <unistd.h>
35 #include "log.h"
37 #ifndef SSH_PATH
38 #define SSH_PATH "/usr/bin/ssh"
39 #endif
41 #define MAXSOCK 32
42 #define BACKOFF 1
43 #define RETRIES 16
45 #ifndef __OpenBSD__
46 #define pledge(p, e) 0
47 #endif
49 const char *addr; /* our addr */
50 const char *ssh_tflag;
51 const char *ssh_dest;
53 char ssh_host[256];
54 char ssh_port[16];
56 struct event sockev[MAXSOCK];
57 int socks[MAXSOCK];
58 int nsock;
60 int debug;
61 int verbose;
63 struct event sighupev;
64 struct event sigintev;
65 struct event sigtermev;
66 struct event sigchldev;
67 struct event siginfoev;
69 struct timeval timeout = {600, 0}; /* 10 minutes */
70 struct event timeoutev;
72 pid_t ssh_pid = -1;
74 int conn;
76 struct conn {
77 int ntentative;
78 struct timeval retry;
79 struct event waitev;
80 int source;
81 struct bufferevent *sourcebev;
82 int to;
83 struct bufferevent *tobev;
84 };
86 static void
87 sig_handler(int sig, short event, void *data)
88 {
89 int status;
91 switch (sig) {
92 case SIGHUP:
93 case SIGINT:
94 case SIGTERM:
95 log_info("quitting");
96 event_loopbreak();
97 break;
98 case SIGCHLD:
99 if (waitpid(ssh_pid, &status, WNOHANG) == -1)
100 fatal("waitpid");
101 ssh_pid = -1;
102 break;
103 #ifdef SIGINFO
104 case SIGINFO:
105 #else
106 case SIGUSR1:
107 #endif
108 log_info("connections: %d", conn);
112 static int
113 spawn_ssh(void)
115 log_debug("spawning ssh");
117 switch (ssh_pid = fork()) {
118 case -1:
119 log_warnx("fork");
120 return -1;
121 case 0:
122 execl(SSH_PATH, "ssh", "-L", ssh_tflag, "-NTq", ssh_dest,
123 NULL);
124 fatal("exec");
125 default:
126 return 0;
130 static void
131 conn_free(struct conn *c)
133 if (c->sourcebev != NULL)
134 bufferevent_free(c->sourcebev);
135 if (c->tobev != NULL)
136 bufferevent_free(c->tobev);
138 if (evtimer_pending(&c->waitev, NULL))
139 evtimer_del(&c->waitev);
141 close(c->source);
142 if (c->to != -1)
143 close(c->to);
145 free(c);
148 static void
149 killing_time(int fd, short event, void *data)
151 if (ssh_pid == -1)
152 return;
154 log_debug("timeout expired, killing ssh (%d)", ssh_pid);
155 kill(ssh_pid, SIGTERM);
156 ssh_pid = -1;
/* Bufferevent callback for events this program does not care about. */
static void
nopcb(struct bufferevent *bev, void *d)
{
	(void)bev;
	(void)d;
}
165 static void
166 sreadcb(struct bufferevent *bev, void *d)
168 struct conn *c = d;
170 bufferevent_write_buffer(c->tobev, EVBUFFER_INPUT(bev));
173 static void
174 treadcb(struct bufferevent *bev, void *d)
176 struct conn *c = d;
178 bufferevent_write_buffer(c->sourcebev, EVBUFFER_INPUT(bev));
181 static void
182 errcb(struct bufferevent *bev, short event, void *d)
184 struct conn *c = d;
186 log_info("closing connection (event=%x)", event);
188 conn_free(c);
190 if (--conn == 0) {
191 log_debug("scheduling ssh termination (%llds)",
192 (long long)timeout.tv_sec);
193 if (timeout.tv_sec != 0) {
194 evtimer_set(&timeoutev, killing_time, NULL);
195 evtimer_add(&timeoutev, &timeout);
/*
 * Open a TCP connection to the ssh forward at ssh_host:ssh_port,
 * trying each address returned by getaddrinfo(3) in turn.
 *
 * Returns the connected socket, or -1 after logging the reason.
 *
 * NOTE(review): connect(2) is blocking here, so the event loop stalls
 * for the duration of each attempt; ssh_host is "localhost" unless -B
 * named a bind address.
 */
static int
connect_to_ssh(void)
{
	struct addrinfo	 hints, *res, *res0;
	int		 r, saved_errno, sock = -1;
	/* initialized so an empty address list cannot read garbage */
	const char	*cause = "no address to connect to";

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;

	r = getaddrinfo(ssh_host, ssh_port, &hints, &res0);
	if (r != 0) {
		log_warnx("getaddrinfo(\"%s\", \"%s\"): %s",
		    ssh_host, ssh_port, gai_strerror(r));
		return -1;
	}

	for (res = res0; res != NULL; res = res->ai_next) {
		sock = socket(res->ai_family, res->ai_socktype,
		    res->ai_protocol);
		if (sock == -1) {
			cause = "socket";
			continue;
		}

		if (connect(sock, res->ai_addr, res->ai_addrlen) == -1) {
			cause = "connect";
			/* keep connect's errno for the log_warn below */
			saved_errno = errno;
			close(sock);
			errno = saved_errno;
			sock = -1;
			continue;
		}

		break;
	}

	if (sock == -1)
		log_warn("%s", cause);

	freeaddrinfo(res0);
	return sock;
}
245 static void
246 try_to_connect(int fd, short event, void *d)
248 struct conn *c = d;
250 /* ssh may have died in the meantime */
251 if (ssh_pid == -1) {
252 conn_free(c);
253 return;
256 c->ntentative++;
257 log_info("trying to connect to %s:%s (%d/%d)", ssh_host, ssh_port,
258 c->ntentative, RETRIES);
260 if ((c->to = connect_to_ssh()) == -1) {
261 if (c->ntentative == RETRIES) {
262 log_warnx("giving up connecting");
263 conn_free(c);
264 return;
267 evtimer_set(&c->waitev, try_to_connect, c);
268 evtimer_add(&c->waitev, &c->retry);
269 return;
272 log_info("connected!");
274 c->sourcebev = bufferevent_new(c->source, sreadcb, nopcb, errcb, c);
275 c->tobev = bufferevent_new(c->to, treadcb, nopcb, errcb, c);
276 if (c->sourcebev == NULL || c->tobev == NULL) {
277 log_warn("bufferevent_new");
278 conn_free(c);
279 return;
282 bufferevent_enable(c->sourcebev, EV_READ|EV_WRITE);
283 bufferevent_enable(c->tobev, EV_READ|EV_WRITE);
286 static void
287 do_accept(int fd, short event, void *data)
289 struct conn *c;
290 int s;
292 log_debug("incoming connection");
294 if ((s = accept(fd, NULL, 0)) == -1) {
295 log_warn("accept");
296 return;
299 if (ssh_pid == -1 && spawn_ssh() == -1) {
300 close(s);
301 return;
304 if ((c = calloc(1, sizeof(*c))) == NULL) {
305 log_warn("calloc");
306 close(s);
307 return;
310 conn++;
311 if (evtimer_pending(&timeoutev, NULL))
312 evtimer_del(&timeoutev);
314 c->source = s;
315 c->to = -1;
316 c->retry.tv_sec = BACKOFF;
317 evtimer_set(&c->waitev, try_to_connect, c);
318 evtimer_add(&c->waitev, &c->retry);
/*
 * Copy the part of `s' preceding its first ':' into the buffer `d' of
 * `len' bytes and NUL-terminate it.
 *
 * Returns a pointer to that ':' inside `s', or NULL when `s' contains
 * no ':' or the prefix plus its terminating NUL does not fit in `len'
 * bytes (in which case `d' is left untouched).
 */
static const char *
copysec(const char *s, char *d, size_t len)
{
	const char	*c;
	size_t		 n;

	if ((c = strchr(s, ':')) == NULL)
		return NULL;

	n = (size_t)(c - s);
	/*
	 * Need n bytes plus the NUL.  Comparing against len (not len-1)
	 * accepts a prefix that exactly fits and cannot wrap for len == 0.
	 */
	if (n >= len)
		return NULL;

	memcpy(d, s, n);
	d[n] = '\0';
	return c;
}
/*
 * Create a socket for `res', bind it and start listening on it.
 *
 * Returns the descriptor, or -1 with *cause naming the call that failed
 * and errno preserved from that call.
 */
static int
open_listener(const struct addrinfo *res, const char **cause)
{
	int	s, v, saved_errno;

	s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
	if (s == -1) {
		*cause = "socket";
		return -1;
	}

	/*
	 * SO_REUSEADDR only influences bind(2) when set before it.
	 * SO_REUSEPORT is deliberately not set: it would let another
	 * socket bind the same port and take over our clients.
	 */
	v = 1;
	if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &v, sizeof(v)) == -1)
		fatal("setsockopt(SO_REUSEADDR)");

	if (bind(s, res->ai_addr, res->ai_addrlen) == -1) {
		*cause = "bind";
		goto fail;
	}
	if (listen(s, 5) == -1) {
		*cause = "listen";
		goto fail;
	}
	return s;

fail:
	saved_errno = errno;
	close(s);
	errno = saved_errno;
	return -1;
}

/*
 * Resolve `addr' (the -b argument, "port" or "host:port") and listen on
 * every resulting address, storing up to MAXSOCK descriptors in socks[]
 * and their count in nsock.  Failing to set up any socket is reported
 * with fatal()/fatalx().
 */
static void
bind_socket(void)
{
	struct addrinfo	 hints, *res, *res0;
	int		 r, s;
	char		 host[256];	/* same capacity as ssh_host */
	const char	*c, *h, *port, *cause = "no usable address";

	if ((c = strchr(addr, ':')) == NULL) {
		h = NULL;		/* wildcard address */
		port = addr;
	} else {
		if ((c = copysec(addr, host, sizeof(host))) == NULL)
			fatalx("name too long: %s", addr);
		h = host;
		port = c+1;
	}

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_flags = AI_PASSIVE;

	r = getaddrinfo(h, port, &hints, &res0);
	if (r != 0)
		fatalx("getaddrinfo(%s): %s", addr, gai_strerror(r));

	for (res = res0; res != NULL && nsock < MAXSOCK; res = res->ai_next) {
		if ((s = open_listener(res, &cause)) == -1)
			continue;
		socks[nsock++] = s;
	}
	if (nsock == 0)
		fatal("%s", cause);

	freeaddrinfo(res0);
}
399 static void
400 parse_sshaddr(void)
402 const char *c;
404 if (isdigit((unsigned char)*ssh_tflag)) {
405 strlcpy(ssh_host, "localhost", sizeof(ssh_host));
406 if (copysec(ssh_tflag, ssh_port, sizeof(ssh_port)) == NULL)
407 goto err;
408 return;
411 if ((c = copysec(ssh_tflag, ssh_host, sizeof(ssh_host))) == NULL)
412 goto err;
413 if (copysec(c+1, ssh_port, sizeof(ssh_port)) == NULL)
414 goto err;
415 return;
417 err:
418 fatal("wrong value for -B");
421 static void __dead
422 usage(void)
424 fprintf(stderr, "usage: %s [-dv] -B sshaddr -b addr [-t timeout]"
425 " destination\n", getprogname());
426 exit(1);
429 int
430 main(int argc, char **argv)
432 int ch, i, fd;
433 const char *errstr;
434 struct stat sb;
436 /*
437 * Ensure we have fds 0-2 open so that we have no issue with
438 * calling bind_socket before daemon(3).
439 */
440 for (i = 0; i < 3; ++i) {
441 if (fstat(i, &sb) == -1) {
442 if ((fd = open("/dev/null", O_RDWR)) != -1) {
443 if (dup2(fd, i) == -1)
444 exit(1);
445 if (fd > i)
446 close(fd);
447 } else
448 exit(1);
452 log_init(1, LOG_DAEMON);
453 log_setverbose(1);
455 while ((ch = getopt(argc, argv, "B:b:dt:v")) != -1) {
456 switch (ch) {
457 case 'B':
458 ssh_tflag = optarg;
459 parse_sshaddr();
460 break;
461 case 'b':
462 addr = optarg;
463 break;
464 case 'd':
465 debug = 1;
466 break;
467 case 't':
468 timeout.tv_sec = strtonum(optarg, 0, INT_MAX, &errstr);
469 if (errstr != NULL)
470 fatalx("timeout is %s: %s", errstr, optarg);
471 break;
472 case 'v':
473 verbose = 1;
474 break;
475 default:
476 usage();
479 argc -= optind;
480 argv += optind;
482 if (argc != 1 || addr == NULL || ssh_tflag == NULL)
483 usage();
485 ssh_dest = argv[0];
487 bind_socket();
489 log_init(debug, LOG_DAEMON);
490 log_setverbose(verbose);
492 if (!debug)
493 daemon(1, 0);
495 signal(SIGPIPE, SIG_IGN);
497 event_init();
499 /* initialize the timer */
500 evtimer_set(&timeoutev, killing_time, NULL);
502 signal_set(&sighupev, SIGHUP, sig_handler, NULL);
503 signal_set(&sigintev, SIGINT, sig_handler, NULL);
504 signal_set(&sigtermev, SIGTERM, sig_handler, NULL);
505 signal_set(&sigchldev, SIGCHLD, sig_handler, NULL);
506 #ifdef SIGINFO
507 signal_set(&siginfoev, SIGINFO, sig_handler, NULL);
508 #else
509 signal_set(&siginfoev, SIGUSR1, sig_handler, NULL);
510 #endif
512 signal_add(&sighupev, NULL);
513 signal_add(&sigintev, NULL);
514 signal_add(&sigtermev, NULL);
515 signal_add(&sigchldev, NULL);
516 signal_add(&siginfoev, NULL);
518 for (i = 0; i < nsock; ++i) {
519 event_set(&sockev[i], socks[i], EV_READ|EV_PERSIST,
520 do_accept, NULL);
521 event_add(&sockev[i], NULL);
524 /*
525 * dns, inet: bind the socket and connect to the childs.
526 * proc, exec: execute ssh on demand.
527 */
528 if (pledge("stdio dns inet proc exec", NULL) == -1)
529 fatal("pledge");
531 log_info("starting");
532 event_dispatch();
534 if (ssh_pid != -1)
535 kill(ssh_pid, SIGINT);
537 return 0;