Blob


1 /*
2 * Copyright (c) 2021 Omar Polo <op@omarpolo.com>
3 *
4 * Permission to use, copy, modify, and distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
7 *
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15 */
17 #include <sys/types.h>
18 #include <sys/socket.h>
19 #include <sys/wait.h>
21 #include <ctype.h>
22 #include <errno.h>
23 #include <event.h>
24 #include <limits.h>
25 #include <netdb.h>
26 #include <signal.h>
27 #include <stdio.h>
28 #include <stdlib.h>
29 #include <string.h>
30 #include <syslog.h>
31 #include <unistd.h>
33 #include "log.h"
35 #ifndef SSH_PATH
36 #define SSH_PATH "/usr/bin/ssh"
37 #endif
39 #define MAXSOCK 32
40 #define MAXCONN 16
41 #define BACKOFF 1
42 #define RETRIES 8
44 #ifndef __OpenBSD__
45 #define pledge(p, e) 0
46 #endif
48 const char *addr; /* our addr */
49 const char *ssh_tflag;
50 const char *ssh_dest;
52 char ssh_host[256];
53 char ssh_port[16];
55 struct event sockev[MAXSOCK];
56 int socks[MAXSOCK];
57 int nsock;
59 int debug;
60 int verbose;
62 struct event sighupev;
63 struct event sigintev;
64 struct event sigtermev;
65 struct event sigchldev;
66 struct event siginfoev;
68 struct timeval timeout = {600, 0}; /* 10 minutes */
69 struct event timeoutev;
71 pid_t ssh_pid = -1;
73 int conn;
75 struct conn {
76 int ntentative;
77 struct timeval retry;
78 struct event waitev;
79 int source;
80 struct bufferevent *sourcebev;
81 int to;
82 struct bufferevent *tobev;
83 } conns[MAXCONN];
/*
 * Callback for SIGHUP, SIGINT and SIGTERM: make event_dispatch() return
 * so that main() can stop ssh and exit.
 */
static void
terminate(int fd, short event, void *data)
{
	(void)fd;
	(void)event;
	(void)data;

	event_loopbreak();
}
91 static void
92 chld(int fd, short event, void *data)
93 {
94 int status;
96 if (waitpid(ssh_pid, &status, WNOHANG) == -1)
97 fatal("waitpid");
99 ssh_pid = -1;
102 static void
103 info(int fd, short event, void *data)
105 log_info("connections: %d", conn);
108 static void
109 spawn_ssh(void)
111 log_debug("spawning ssh");
113 switch (ssh_pid = fork()) {
114 case -1:
115 fatal("fork");
116 case 0:
117 execl(SSH_PATH, "ssh", "-L", ssh_tflag, "-NTq", ssh_dest,
118 NULL);
119 fatal("exec");
120 default:
121 return;
125 static void
126 killing_time(int fd, short event, void *data)
128 if (ssh_pid == -1)
129 return;
131 log_debug("timeout expired, killing ssh (%d)", ssh_pid);
132 kill(ssh_pid, SIGTERM);
133 ssh_pid = -1;
/* bufferevent write callback that deliberately does nothing. */
static void
nopcb(struct bufferevent *bev, void *d)
{
	(void)bev;
	(void)d;
}
142 static void
143 sreadcb(struct bufferevent *bev, void *d)
145 struct conn *c = d;
147 bufferevent_write_buffer(c->tobev, EVBUFFER_INPUT(bev));
150 static void
151 treadcb(struct bufferevent *bev, void *d)
153 struct conn *c = d;
155 bufferevent_write_buffer(c->sourcebev, EVBUFFER_INPUT(bev));
158 static void
159 errcb(struct bufferevent *bev, short event, void *d)
161 struct conn *c = d;
163 log_info("closing connection (event=%x)", event);
165 bufferevent_free(c->sourcebev);
166 bufferevent_free(c->tobev);
168 close(c->source);
169 close(c->to);
171 c->source = -1;
172 c->to = -1;
174 if (--conn == 0) {
175 log_debug("scheduling ssh termination (%llds)",
176 (long long)timeout.tv_sec);
177 if (timeout.tv_sec != 0) {
178 evtimer_set(&timeoutev, killing_time, NULL);
179 evtimer_add(&timeoutev, &timeout);
/*
 * Open a TCP connection to ssh_host:ssh_port, trying every address
 * getaddrinfo(3) returns until one succeeds.
 *
 * Returns the connected socket, or -1 (after logging why) when no
 * address could be reached.  A name resolution failure is fatal.
 */
static int
connect_to_ssh(void)
{
	struct addrinfo	 hints, *res, *res0;
	int		 r, saved_errno, sock = -1;
	const char	*cause = "no usable address";

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;

	r = getaddrinfo(ssh_host, ssh_port, &hints, &res0);
	if (r != 0)
		/* getaddrinfo reports via its return value, not errno */
		fatalx("getaddrinfo(\"%s\", \"%s\"): %s",
		    ssh_host, ssh_port, gai_strerror(r));

	for (res = res0; res; res = res->ai_next) {
		sock = socket(res->ai_family, res->ai_socktype,
		    res->ai_protocol);
		if (sock == -1) {
			cause = "socket";
			continue;
		}

		if (connect(sock, res->ai_addr, res->ai_addrlen) == -1) {
			cause = "connect";
			/* keep connect's errno for log_warn below */
			saved_errno = errno;
			close(sock);
			errno = saved_errno;
			sock = -1;
			continue;
		}

		break;
	}

	if (sock == -1)
		log_warn("%s", cause);

	freeaddrinfo(res0);
	return sock;
}
227 static void
228 try_to_connect(int fd, short event, void *d)
230 struct conn *c = d;
232 /* ssh may die in the meantime */
233 if (ssh_pid == -1) {
234 close(c->source);
235 c->source = -1;
236 return;
239 c->ntentative++;
240 log_info("trying to connect to %s:%s (%d/%d)", ssh_host, ssh_port,
241 c->ntentative, RETRIES);
243 if ((c->to = connect_to_ssh()) == -1) {
244 if (c->ntentative == RETRIES) {
245 log_warnx("giving up connecting");
246 close(c->source);
247 c->source = -1;
248 return;
251 evtimer_set(&c->waitev, try_to_connect, c);
252 evtimer_add(&c->waitev, &c->retry);
253 return;
256 c->sourcebev = bufferevent_new(c->source, sreadcb, nopcb, errcb, c);
257 c->tobev = bufferevent_new(c->to, treadcb, nopcb, errcb, c);
258 if (c->sourcebev == NULL || c->tobev == NULL)
259 fatal("bufferevent_new");
260 bufferevent_enable(c->sourcebev, EV_READ|EV_WRITE);
261 bufferevent_enable(c->tobev, EV_READ|EV_WRITE);
264 static void
265 do_accept(int fd, short event, void *data)
267 int s, i;
269 log_debug("incoming connection");
271 if (evtimer_pending(&timeoutev, NULL))
272 evtimer_del(&timeoutev);
274 if ((s = accept(fd, NULL, 0)) == -1)
275 fatal("accept");
277 if (conn == MAXCONN) {
278 log_warnx("dropping the connection, too many already");
279 close(s);
280 return;
283 conn++;
285 if (ssh_pid == -1)
286 spawn_ssh();
288 for (i = 0; i < MAXCONN; ++i) {
289 if (conns[i].source != -1)
290 continue;
292 conns[i].source = s;
293 conns[i].ntentative = 0;
294 conns[i].retry.tv_sec = BACKOFF;
295 conns[i].retry.tv_usec = 0;
296 evtimer_set(&conns[i].waitev, try_to_connect, &conns[i]);
297 evtimer_add(&conns[i].waitev, &conns[i].retry);
298 break;
/*
 * Copy the section of `s` that precedes its first ':' into `d`, a buffer
 * of `len` bytes, NUL-terminating it (the whole buffer is zeroed first).
 *
 * Returns a pointer to that ':' inside `s`, or NULL when `s` has no ':'
 * or the section plus its terminating NUL does not fit in `d`.
 */
static const char *
copysec(const char *s, char *d, size_t len)
{
	const char	*c;
	size_t		 n;

	if ((c = strchr(s, ':')) == NULL)
		return NULL;

	n = (size_t)(c - s);
	/* need room for n bytes plus the NUL; this also rejects len == 0 */
	if (n >= len)
		return NULL;

	memset(d, 0, len);
	memcpy(d, s, n);
	return c;
}
/*
 * Create the listening sockets for `addr` ("port" or "host:port"), one
 * per address returned by getaddrinfo(3), up to MAXSOCK.  They are
 * stored in socks[]/nsock.  Fatal if no socket could be set up.
 */
static void
bind_socket(void)
{
	struct addrinfo	 hints, *res, *res0;
	int		 v, r, s, saved_errno;
	char		 host[64];
	const char	*c, *h, *port, *cause = "no usable address";

	if ((c = strchr(addr, ':')) == NULL) {
		h = NULL;
		port = addr;
	} else {
		if ((c = copysec(addr, host, sizeof(host))) == NULL)
			fatalx("name too long: %s", addr);

		h = host;
		port = c+1;
	}

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_flags = AI_PASSIVE;

	r = getaddrinfo(h, port, &hints, &res0);
	if (r != 0)
		fatalx("getaddrinfo(%s): %s", addr, gai_strerror(r));

	for (res = res0; res && nsock < MAXSOCK; res = res->ai_next) {
		s = socket(res->ai_family, res->ai_socktype,
		    res->ai_protocol);
		if (s == -1) {
			cause = "socket";
			continue;
		}

		/* socket options affecting bind(2) must be set before it */
		v = 1;
		if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &v,
		    sizeof(v)) == -1)
			fatal("setsockopt(SO_REUSEADDR)");

#ifdef SO_REUSEPORT
		v = 1;
		if (setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &v,
		    sizeof(v)) == -1)
			fatal("setsockopt(SO_REUSEPORT)");
#endif

		if (bind(s, res->ai_addr, res->ai_addrlen) == -1) {
			cause = "bind";
			saved_errno = errno;
			close(s);
			errno = saved_errno;
			continue;
		}

		if (listen(s, 5) == -1) {
			cause = "listen";
			saved_errno = errno;
			close(s);
			errno = saved_errno;
			continue;
		}

		socks[nsock++] = s;
	}

	if (nsock == 0)
		fatal("%s", cause);

	freeaddrinfo(res0);
}
380 static void
381 parse_tflag(void)
383 const char *c;
385 if (isdigit(*ssh_tflag)) {
386 strlcpy(ssh_host, "localhost", sizeof(ssh_host));
387 if (copysec(ssh_tflag, ssh_port, sizeof(ssh_port)) == NULL)
388 goto err;
389 return;
392 if ((c = copysec(ssh_tflag, ssh_host, sizeof(ssh_host))) == NULL)
393 goto err;
394 if (copysec(c+1, ssh_port, sizeof(ssh_port)) == NULL)
395 goto err;
396 return;
398 err:
399 fatal("wrong value for -B");
402 static void __dead
403 usage(void)
405 fprintf(stderr, "usage: %s [-dv] -B sshaddr -b addr [-t timeout]"
406 " destination\n", getprogname());
407 exit(1);
410 int
411 main(int argc, char **argv)
413 int ch, i;
414 const char *errstr;
416 log_init(1, LOG_DAEMON);
417 log_setverbose(1);
419 while ((ch = getopt(argc, argv, "B:b:dt:v")) != -1) {
420 switch (ch) {
421 case 'B':
422 ssh_tflag = optarg;
423 parse_tflag();
424 break;
425 case 'b':
426 addr = optarg;
427 break;
428 case 'd':
429 debug = 1;
430 break;
431 case 't':
432 timeout.tv_sec = strtonum(optarg, 0, INT_MAX, &errstr);
433 if (errstr != NULL)
434 fatalx("timeout is %s: %s", errstr, optarg);
435 break;
436 case 'v':
437 verbose = 1;
438 break;
439 default:
440 usage();
443 argc -= optind;
444 argv += optind;
446 if (argc != 1 || addr == NULL || ssh_tflag == NULL)
447 usage();
449 ssh_dest = argv[0];
451 for (i = 0; i < MAXCONN; ++i) {
452 conns[i].source = -1;
453 conns[i].to = -1;
456 log_init(debug, LOG_DAEMON);
457 log_setverbose(verbose);
459 if (!debug)
460 daemon(1, 0);
462 bind_socket();
464 signal(SIGPIPE, SIG_IGN);
466 event_init();
468 /* initialize the timer */
469 evtimer_set(&timeoutev, killing_time, NULL);
471 signal_set(&sighupev, SIGHUP, terminate, NULL);
472 signal_set(&sigintev, SIGINT, terminate, NULL);
473 signal_set(&sigtermev, SIGTERM, terminate, NULL);
474 signal_set(&sigchldev, SIGCHLD, chld, NULL);
475 #ifdef SIGINFO
476 signal_set(&siginfoev, SIGINFO, info, NULL);
477 #else
478 signal_set(&siginfoev, SIGUSR1, info, NULL);
479 #endif
481 signal_add(&sighupev, NULL);
482 signal_add(&sigintev, NULL);
483 signal_add(&sigtermev, NULL);
484 signal_add(&sigchldev, NULL);
485 signal_add(&siginfoev, NULL);
487 for (i = 0; i < nsock; ++i) {
488 event_set(&sockev[i], socks[i], EV_READ|EV_PERSIST,
489 do_accept, NULL);
490 event_add(&sockev[i], NULL);
493 /*
494 * dns, inet: bind the socket and connect to the childs.
495 * proc, exec: execute ssh on demand.
496 */
497 if (pledge("stdio dns inet proc exec", NULL) == -1)
498 fatal("pledge");
500 log_info("starting");
501 event_dispatch();
503 if (ssh_pid != -1)
504 kill(ssh_pid, SIGINT);
506 return 0;