Blob


/*
 * Copyright (c) 2021 Omar Polo <op@omarpolo.com>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */
17 #include <sys/types.h>
18 #include <sys/socket.h>
19 #include <sys/wait.h>
21 #include <ctype.h>
22 #include <errno.h>
23 #include <event.h>
24 #include <limits.h>
25 #include <netdb.h>
26 #include <signal.h>
27 #include <stdio.h>
28 #include <stdlib.h>
29 #include <string.h>
30 #include <syslog.h>
31 #include <unistd.h>
33 #include "log.h"
/* ssh(1) binary to execute; may be overridden at build time. */
#ifndef SSH_PATH
#define SSH_PATH	"/usr/bin/ssh"
#endif

enum {
	MAXSOCK	= 32,	/* maximum number of listening sockets */
	BACKOFF	= 1,	/* seconds to wait between attempts to reach ssh */
	RETRIES	= 16,	/* attempts to reach ssh before giving up */
};

/* pledge(2) only exists on OpenBSD: elsewhere pretend it succeeded. */
#ifndef __OpenBSD__
#define pledge(p, e) 0
#endif
47 const char *addr; /* our addr */
48 const char *ssh_tflag;
49 const char *ssh_dest;
51 char ssh_host[256];
52 char ssh_port[16];
54 struct event sockev[MAXSOCK];
55 int socks[MAXSOCK];
56 int nsock;
58 int debug;
59 int verbose;
61 struct event sighupev;
62 struct event sigintev;
63 struct event sigtermev;
64 struct event sigchldev;
65 struct event siginfoev;
67 struct timeval timeout = {600, 0}; /* 10 minutes */
68 struct event timeoutev;
70 pid_t ssh_pid = -1;
72 int conn;
74 struct conn {
75 int ntentative;
76 struct timeval retry;
77 struct event waitev;
78 int source;
79 struct bufferevent *sourcebev;
80 int to;
81 struct bufferevent *tobev;
82 };
84 static void
85 sig_handler(int sig, short event, void *data)
86 {
87 int status;
89 switch (sig) {
90 case SIGHUP:
91 case SIGINT:
92 case SIGTERM:
93 event_loopbreak();
94 break;
95 case SIGCHLD:
96 if (waitpid(ssh_pid, &status, WNOHANG) == -1)
97 fatal("waitpid");
98 ssh_pid = -1;
99 break;
100 #ifdef SIGINFO
101 case SIGINFO:
102 #else
103 case SIGUSR1:
104 #endif
105 log_info("connections: %d", conn);
109 static void
110 spawn_ssh(void)
112 log_debug("spawning ssh");
114 switch (ssh_pid = fork()) {
115 case -1:
116 fatal("fork");
117 case 0:
118 execl(SSH_PATH, "ssh", "-L", ssh_tflag, "-NTq", ssh_dest,
119 NULL);
120 fatal("exec");
121 default:
122 return;
126 static void
127 conn_free(struct conn *c)
129 if (c->sourcebev != NULL)
130 bufferevent_free(c->sourcebev);
131 if (c->tobev != NULL)
132 bufferevent_free(c->tobev);
134 if (evtimer_pending(&c->waitev, NULL))
135 evtimer_del(&c->waitev);
137 close(c->source);
138 close(c->to);
140 free(c);
143 static void
144 killing_time(int fd, short event, void *data)
146 if (ssh_pid == -1)
147 return;
149 log_debug("timeout expired, killing ssh (%d)", ssh_pid);
150 kill(ssh_pid, SIGTERM);
151 ssh_pid = -1;
/* Bufferevent write callback that deliberately does nothing. */
static void
nopcb(struct bufferevent *bev, void *d)
{
	(void)bev;
	(void)d;
}
160 static void
161 sreadcb(struct bufferevent *bev, void *d)
163 struct conn *c = d;
165 bufferevent_write_buffer(c->tobev, EVBUFFER_INPUT(bev));
168 static void
169 treadcb(struct bufferevent *bev, void *d)
171 struct conn *c = d;
173 bufferevent_write_buffer(c->sourcebev, EVBUFFER_INPUT(bev));
176 static void
177 errcb(struct bufferevent *bev, short event, void *d)
179 struct conn *c = d;
181 log_info("closing connection (event=%x)", event);
183 conn_free(c);
185 if (--conn == 0) {
186 log_debug("scheduling ssh termination (%llds)",
187 (long long)timeout.tv_sec);
188 if (timeout.tv_sec != 0) {
189 evtimer_set(&timeoutev, killing_time, NULL);
190 evtimer_add(&timeoutev, &timeout);
/*
 * Open a stream connection to ssh_host:ssh_port (the local end of the
 * ssh forwarding), trying each address getaddrinfo() returns in order.
 * Returns the connected socket, or -1 after logging the failure.
 */
static int
connect_to_ssh(void)
{
	struct addrinfo	 hints, *ai, *ai0;
	const char	*why = NULL;
	int		 error, fd = -1, serrno;

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;

	if ((error = getaddrinfo(ssh_host, ssh_port, &hints, &ai0)) != 0) {
		log_warnx("getaddrinfo(\"%s\", \"%s\"): %s",
		    ssh_host, ssh_port, gai_strerror(error));
		return -1;
	}

	for (ai = ai0; ai != NULL; ai = ai->ai_next) {
		fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
		if (fd == -1) {
			why = "socket";
			continue;
		}

		if (connect(fd, ai->ai_addr, ai->ai_addrlen) != -1)
			break;

		/* keep connect's errno for the final log_warn */
		why = "connect";
		serrno = errno;
		close(fd);
		errno = serrno;
		fd = -1;
	}

	if (fd == -1)
		log_warn("%s", why);

	freeaddrinfo(ai0);
	return fd;
}
240 static void
241 try_to_connect(int fd, short event, void *d)
243 struct conn *c = d;
245 /* ssh may die in the meantime */
246 if (ssh_pid == -1) {
247 close(c->source);
248 c->source = -1;
249 return;
252 c->ntentative++;
253 log_info("trying to connect to %s:%s (%d/%d)", ssh_host, ssh_port,
254 c->ntentative, RETRIES);
256 if ((c->to = connect_to_ssh()) == -1) {
257 if (c->ntentative == RETRIES) {
258 log_warnx("giving up connecting");
259 close(c->source);
260 c->source = -1;
261 return;
264 evtimer_set(&c->waitev, try_to_connect, c);
265 evtimer_add(&c->waitev, &c->retry);
266 return;
269 c->sourcebev = bufferevent_new(c->source, sreadcb, nopcb, errcb, c);
270 c->tobev = bufferevent_new(c->to, treadcb, nopcb, errcb, c);
271 if (c->sourcebev == NULL || c->tobev == NULL) {
272 log_warn("bufferevent_new");
273 conn_free(c);
274 return;
277 bufferevent_enable(c->sourcebev, EV_READ|EV_WRITE);
278 bufferevent_enable(c->tobev, EV_READ|EV_WRITE);
281 static void
282 do_accept(int fd, short event, void *data)
284 struct conn *c;
285 int s;
287 log_debug("incoming connection");
289 if (evtimer_pending(&timeoutev, NULL))
290 evtimer_del(&timeoutev);
292 if ((s = accept(fd, NULL, 0)) == -1)
293 fatal("accept");
295 conn++;
297 if (ssh_pid == -1)
298 spawn_ssh();
300 if ((c = calloc(1, sizeof(*c))) == NULL) {
301 log_warn("calloc");
302 close(s);
303 return;
306 c->source = s;
307 c->retry.tv_sec = BACKOFF;
308 evtimer_set(&c->waitev, try_to_connect, c);
309 evtimer_add(&c->waitev, &c->retry);
/*
 * Copy the part of s that precedes its first ':' into d, a buffer of
 * len bytes, NUL-terminating it (the rest of d is zeroed).
 *
 * Returns a pointer to that ':' inside s, or NULL when s contains no
 * ':' or when the section plus its terminating NUL does not fit in
 * len bytes.
 */
static const char *
copysec(const char *s, char *d, size_t len)
{
	const char *c;
	size_t n;

	if ((c = strchr(s, ':')) == NULL)
		return NULL;

	n = (size_t)(c - s);
	/* room is needed for n bytes plus the NUL; also rejects len == 0 */
	if (n >= len)
		return NULL;

	memset(d, 0, len);
	memcpy(d, s, n);
	return c;
}
/*
 * Resolve addr ("port" or "host:port") and open up to MAXSOCK listening
 * stream sockets for it, storing them in socks[] and counting them in
 * nsock.  Exits via fatal()/fatalx() when nothing could be bound.
 */
static void
bind_socket(void)
{
	struct addrinfo hints, *res, *res0;
	int v, r, s, saved_errno;
	char host[64];
	const char *c, *h, *port, *cause = NULL;

	if (strchr(addr, ':') == NULL) {
		h = NULL;
		port = addr;
	} else {
		if ((c = copysec(addr, host, sizeof(host))) == NULL)
			fatalx("name too long: %s", addr);

		h = host;
		port = c+1;
	}

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_flags = AI_PASSIVE;

	r = getaddrinfo(h, port, &hints, &res0);
	if (r != 0)
		fatalx("getaddrinfo(%s): %s", addr, gai_strerror(r));

	for (res = res0; res && nsock < MAXSOCK; res = res->ai_next) {
		s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
		if (s == -1) {
			cause = "socket";
			continue;
		}

		/*
		 * These options influence bind(2), so they only have an
		 * effect when set before it.
		 */
		v = 1;
		if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &v,
		    sizeof(v)) == -1)
			fatal("setsockopt(SO_REUSEADDR)");

#ifdef SO_REUSEPORT
		v = 1;
		if (setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &v,
		    sizeof(v)) == -1)
			fatal("setsockopt(SO_REUSEPORT)");
#endif

		if (bind(s, res->ai_addr, res->ai_addrlen) == -1) {
			cause = "bind";
			saved_errno = errno;
			close(s);
			errno = saved_errno;
			continue;
		}

		if (listen(s, 5) == -1) {
			cause = "listen";
			saved_errno = errno;
			close(s);
			errno = saved_errno;
			continue;
		}

		socks[nsock++] = s;
	}

	if (nsock == 0)
		fatal("%s", cause);

	freeaddrinfo(res0);
}
390 static void
391 parse_tflag(void)
393 const char *c;
395 if (isdigit((unsigned char)*ssh_tflag)) {
396 strlcpy(ssh_host, "localhost", sizeof(ssh_host));
397 if (copysec(ssh_tflag, ssh_port, sizeof(ssh_port)) == NULL)
398 goto err;
399 return;
402 if ((c = copysec(ssh_tflag, ssh_host, sizeof(ssh_host))) == NULL)
403 goto err;
404 if (copysec(c+1, ssh_port, sizeof(ssh_port)) == NULL)
405 goto err;
406 return;
408 err:
409 fatal("wrong value for -B");
412 static void __dead
413 usage(void)
415 fprintf(stderr, "usage: %s [-dv] -B sshaddr -b addr [-t timeout]"
416 " destination\n", getprogname());
417 exit(1);
420 int
421 main(int argc, char **argv)
423 int ch, i;
424 const char *errstr;
426 log_init(1, LOG_DAEMON);
427 log_setverbose(1);
429 while ((ch = getopt(argc, argv, "B:b:dt:v")) != -1) {
430 switch (ch) {
431 case 'B':
432 ssh_tflag = optarg;
433 parse_tflag();
434 break;
435 case 'b':
436 addr = optarg;
437 break;
438 case 'd':
439 debug = 1;
440 break;
441 case 't':
442 timeout.tv_sec = strtonum(optarg, 0, INT_MAX, &errstr);
443 if (errstr != NULL)
444 fatalx("timeout is %s: %s", errstr, optarg);
445 break;
446 case 'v':
447 verbose = 1;
448 break;
449 default:
450 usage();
453 argc -= optind;
454 argv += optind;
456 if (argc != 1 || addr == NULL || ssh_tflag == NULL)
457 usage();
459 ssh_dest = argv[0];
461 log_init(debug, LOG_DAEMON);
462 log_setverbose(verbose);
464 if (!debug)
465 daemon(1, 0);
467 bind_socket();
469 signal(SIGPIPE, SIG_IGN);
471 event_init();
473 /* initialize the timer */
474 evtimer_set(&timeoutev, killing_time, NULL);
476 signal_set(&sighupev, SIGHUP, sig_handler, NULL);
477 signal_set(&sigintev, SIGINT, sig_handler, NULL);
478 signal_set(&sigtermev, SIGTERM, sig_handler, NULL);
479 signal_set(&sigchldev, SIGCHLD, sig_handler, NULL);
480 #ifdef SIGINFO
481 signal_set(&siginfoev, SIGINFO, sig_handler, NULL);
482 #else
483 signal_set(&siginfoev, SIGUSR1, sig_handler, NULL);
484 #endif
486 signal_add(&sighupev, NULL);
487 signal_add(&sigintev, NULL);
488 signal_add(&sigtermev, NULL);
489 signal_add(&sigchldev, NULL);
490 signal_add(&siginfoev, NULL);
492 for (i = 0; i < nsock; ++i) {
493 event_set(&sockev[i], socks[i], EV_READ|EV_PERSIST,
494 do_accept, NULL);
495 event_add(&sockev[i], NULL);
498 /*
499 * dns, inet: bind the socket and connect to the childs.
500 * proc, exec: execute ssh on demand.
501 */
502 if (pledge("stdio dns inet proc exec", NULL) == -1)
503 fatal("pledge");
505 log_info("starting");
506 event_dispatch();
508 if (ssh_pid != -1)
509 kill(ssh_pid, SIGINT);
511 return 0;