Blob


1 /*
2 * Copyright (c) 2021 Omar Polo <op@omarpolo.com>
3 *
4 * Permission to use, copy, modify, and distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
7 *
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15 */
#include "compat.h"

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/wait.h>

#include <ctype.h>
#include <errno.h>
#include <event.h>
#include <fcntl.h>
#include <limits.h>
#include <netdb.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <syslog.h>
#include <unistd.h>

#include "log.h"
/* ssh program handed to exec in spawn_ssh(); may be overridden at build time. */
#ifndef SSH_PATH
#define SSH_PATH "ssh"
#endif

#define MAXSOCK 4	/* max number of listening sockets (bind_socket) */
#define MAXCONN 16	/* max number of simultaneous client connections */
#define BACKOFF 1	/* seconds to wait between attempts to reach ssh */
#define RETRIES 8	/* attempts to reach ssh before dropping a client */

/* pledge(2) only exists on OpenBSD: elsewhere it is a no-op that succeeds. */
#ifndef __OpenBSD__
#define pledge(p, e) 0
#endif

const char *addr; /* our addr: [host:]port to listen on, from -b */
const char *ssh_tflag;	/* -B value, passed verbatim to ssh -L */
const char *ssh_dest;	/* destination argument passed to ssh */

/* where ssh's forwarded port can be reached; parsed from ssh_tflag */
char ssh_host[256];
char ssh_port[16];

/* listening sockets and their events; the first nsock are in use */
struct event sockev[MAXSOCK];
int socks[MAXSOCK];
int nsock;

int debug;	/* -d: don't daemonize */
int verbose;	/* -v: passed to log_setverbose() */

/* signal events, registered in main() */
struct event sighupev;
struct event sigintev;
struct event sigtermev;
struct event sigchldev;
struct event siginfoev;

/*
 * How long ssh is kept running after the last connection closes (-t);
 * a value of 0 means the termination timer is never armed.
 */
struct timeval timeout = {120, 0};
struct event timeoutev;	/* timer that runs killing_time() */

pid_t ssh_pid = -1;	/* pid of the ssh we spawned, -1 if none */

int conn;	/* number of accepted connections currently accounted for */
/*
 * State of one proxied client connection.  A slot is free when
 * source is -1 (see do_accept() and errcb()).
 */
struct conn {
	int ntentative;			/* attempts made so far to reach ssh */
	struct timeval retry;		/* delay between two attempts */
	struct event waitev;		/* timer that runs try_to_connect() */
	int source;			/* accepted client socket */
	struct bufferevent *sourcebev;	/* bufferevent over source */
	int to;				/* socket connected to ssh's port */
	struct bufferevent *tobev;	/* bufferevent over to */
} conns[MAXCONN];
/*
 * Handler for SIGHUP, SIGINT and SIGTERM: leave the event loop so that
 * main() can stop ssh and return.
 */
static void
terminate(int fd, short event, void *data)
{
	(void)fd;
	(void)event;
	(void)data;

	event_loopbreak();
}
93 static void
94 chld(int fd, short event, void *data)
95 {
96 int status;
98 if (waitpid(ssh_pid, &status, WNOHANG) == -1)
99 fatal("waitpid");
101 ssh_pid = -1;
104 static void
105 info(int fd, short event, void *data)
107 log_info("connections: %d", conn);
110 static void
111 spawn_ssh(void)
113 log_debug("spawning ssh");
115 switch (ssh_pid = fork()) {
116 case -1:
117 fatal("fork");
118 case 0:
119 execl(SSH_PATH, "ssh", "-L", ssh_tflag, "-NTq", ssh_dest,
120 NULL);
121 fatal("exec");
122 default:
123 return;
127 static void
128 killing_time(int fd, short event, void *data)
130 if (ssh_pid == -1)
131 return;
133 log_debug("timeout expired, killing ssh (%d)", ssh_pid);
134 kill(ssh_pid, SIGTERM);
135 ssh_pid = -1;
/* Bufferevent callback that does nothing; used as the write callback. */
static void
nopcb(struct bufferevent *bev, void *d)
{
	(void)bev;
	(void)d;
}
144 static void
145 sreadcb(struct bufferevent *bev, void *d)
147 struct conn *c = d;
149 bufferevent_write_buffer(c->tobev, EVBUFFER_INPUT(bev));
152 static void
153 treadcb(struct bufferevent *bev, void *d)
155 struct conn *c = d;
157 bufferevent_write_buffer(c->sourcebev, EVBUFFER_INPUT(bev));
160 static void
161 errcb(struct bufferevent *bev, short event, void *d)
163 struct conn *c = d;
165 log_info("closing connection (event=%x)", event);
167 bufferevent_free(c->sourcebev);
168 bufferevent_free(c->tobev);
170 close(c->source);
171 close(c->to);
173 c->source = -1;
174 c->to = -1;
176 if (--conn == 0) {
177 log_debug("scheduling ssh termination (%llds)",
178 (long long)timeout.tv_sec);
179 if (timeout.tv_sec != 0) {
180 evtimer_set(&timeoutev, killing_time, NULL);
181 evtimer_add(&timeoutev, &timeout);
/*
 * Open a TCP connection to ssh's forwarded port (ssh_host:ssh_port),
 * trying each address getaddrinfo() returns in turn.
 *
 * Returns the connected socket, or -1 after logging the last failure.
 * A getaddrinfo() failure is fatal.
 */
static int
connect_to_ssh(void)
{
	struct addrinfo hints, *res, *res0;
	int r, saved_errno, sock = -1;
	const char *cause = "no address to connect to";

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;

	/* gai errors are not errno values: fatalx, as bind_socket() does */
	r = getaddrinfo(ssh_host, ssh_port, &hints, &res0);
	if (r != 0)
		fatalx("getaddrinfo(\"%s\", \"%s\"): %s",
		    ssh_host, ssh_port, gai_strerror(r));

	for (res = res0; res; res = res->ai_next) {
		sock = socket(res->ai_family, res->ai_socktype,
		    res->ai_protocol);
		if (sock == -1) {
			cause = "socket";
			continue;
		}

		if (connect(sock, res->ai_addr, res->ai_addrlen) == -1) {
			cause = "connect";
			saved_errno = errno;
			close(sock);
			errno = saved_errno;
			sock = -1;
			continue;
		}

		/* don't leak this socket into an ssh spawned later on */
		if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1)
			fatal("fcntl");

		break;
	}

	if (sock == -1)
		log_warn("%s", cause);

	freeaddrinfo(res0);
	return sock;
}
229 static void
230 try_to_connect(int fd, short event, void *d)
232 struct conn *c = d;
234 /* ssh may die in the meantime */
235 if (ssh_pid == -1) {
236 close(c->source);
237 c->source = -1;
238 return;
241 c->ntentative++;
242 log_debug("trying to connect to %s:%s (%d/%d)", ssh_host, ssh_port,
243 c->ntentative, RETRIES);
245 if ((c->to = connect_to_ssh()) == -1) {
246 if (c->ntentative == RETRIES) {
247 log_warnx("giving up connecting");
248 close(c->source);
249 c->source = -1;
250 return;
253 evtimer_set(&c->waitev, try_to_connect, c);
254 evtimer_add(&c->waitev, &c->retry);
255 return;
258 c->sourcebev = bufferevent_new(c->source, sreadcb, nopcb, errcb, c);
259 c->tobev = bufferevent_new(c->to, treadcb, nopcb, errcb, c);
260 if (c->sourcebev == NULL || c->tobev == NULL)
261 fatal("bufferevent_new");
262 bufferevent_enable(c->sourcebev, EV_READ|EV_WRITE);
263 bufferevent_enable(c->tobev, EV_READ|EV_WRITE);
266 static void
267 do_accept(int fd, short event, void *data)
269 int s, i;
271 log_debug("incoming connection");
273 if (evtimer_pending(&timeoutev, NULL))
274 evtimer_del(&timeoutev);
276 if ((s = accept(fd, NULL, 0)) == -1)
277 fatal("accept");
279 if (conn == MAXCONN) {
280 log_warnx("dropping the connection, too many already");
281 close(s);
282 return;
285 conn++;
287 if (ssh_pid == -1)
288 spawn_ssh();
290 for (i = 0; i < MAXCONN; ++i) {
291 if (conns[i].source != -1)
292 continue;
294 conns[i].source = s;
295 conns[i].ntentative = 0;
296 conns[i].retry.tv_sec = BACKOFF;
297 conns[i].retry.tv_usec = 0;
298 evtimer_set(&conns[i].waitev, try_to_connect, &conns[i]);
299 evtimer_add(&conns[i].waitev, &conns[i].retry);
300 break;
/*
 * Copy the part of `s' that precedes its first ':' into `d', a buffer
 * of `len' bytes, and NUL-terminate it.
 *
 * Returns a pointer to that ':' inside `s', or NULL when `s' contains
 * no ':' or when the prefix does not fit (it may be at most len - 1
 * bytes long, leaving room for the terminator).
 */
static const char *
copysec(const char *s, char *d, size_t len)
{
	const char *c;
	size_t n;

	if ((c = strchr(s, ':')) == NULL)
		return NULL;

	n = (size_t)(c - s);
	/* n bytes plus the NUL must fit: n <= len - 1 (was off by one) */
	if (n >= len)
		return NULL;

	memset(d, 0, len);
	memcpy(d, s, n);
	return c;
}
/*
 * Create, bind and listen on up to MAXSOCK sockets for the -b address
 * (`port' for the wildcard address, or `host:port'), storing them in
 * socks[]/nsock.  Exits if no socket could be set up.
 */
static void
bind_socket(void)
{
	struct addrinfo hints, *res, *res0;
	int r, s, saved_errno, on = 1;
	char host[64];
	const char *c, *h, *port;
	const char *cause = "no address to bind to";

	if (strchr(addr, ':') == NULL) {
		h = NULL;
		port = addr;
	} else {
		if ((c = copysec(addr, host, sizeof(host))) == NULL)
			fatalx("name too long: %s", addr);
		h = host;
		port = c + 1;
	}

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_flags = AI_PASSIVE;

	r = getaddrinfo(h, port, &hints, &res0);
	if (r != 0)
		fatalx("getaddrinfo(%s): %s", addr, gai_strerror(r));

	for (res = res0; res && nsock < MAXSOCK; res = res->ai_next) {
		s = socket(res->ai_family, res->ai_socktype,
		    res->ai_protocol);
		if (s == -1) {
			cause = "socket";
			continue;
		}

		/* allow restarting while old connections sit in TIME_WAIT */
		if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &on,
		    sizeof(on)) == -1)
			log_warn("setsockopt(SO_REUSEADDR)");

		/* the ssh processes we spawn must not hold our listeners */
		if (fcntl(s, F_SETFD, FD_CLOEXEC) == -1)
			fatal("fcntl");

		if (bind(s, res->ai_addr, res->ai_addrlen) == -1) {
			cause = "bind";
			saved_errno = errno;
			close(s);
			errno = saved_errno;
			continue;
		}

		if (listen(s, 5) == -1) {
			cause = "listen";
			saved_errno = errno;
			close(s);
			errno = saved_errno;
			continue;
		}

		socks[nsock++] = s;
	}

	if (nsock == 0)
		fatal("%s", cause);

	freeaddrinfo(res0);
}
372 static void
373 parse_tflag(void)
375 const char *c;
377 if (isdigit(*ssh_tflag)) {
378 strlcpy(ssh_host, "localhost", sizeof(ssh_host));
379 if (copysec(ssh_tflag, ssh_port, sizeof(ssh_port)) == NULL)
380 goto err;
381 return;
384 if ((c = copysec(ssh_tflag, ssh_host, sizeof(ssh_host))) == NULL)
385 goto err;
386 if (copysec(c+1, ssh_port, sizeof(ssh_port)) == NULL)
387 goto err;
388 return;
390 err:
391 fatal("wrong value for -B");
394 static void __dead
395 usage(void)
397 fprintf(stderr, "usage: %s [-dv] -B sshaddr -b addr [-t timeout]"
398 " destination\n", getprogname());
399 exit(1);
402 int
403 main(int argc, char **argv)
405 int ch, i;
406 const char *errstr;
408 log_init(1, LOG_DAEMON);
409 log_setverbose(1);
411 while ((ch = getopt(argc, argv, "B:b:dt:v")) != -1) {
412 switch (ch) {
413 case 'B':
414 ssh_tflag = optarg;
415 parse_tflag();
416 break;
417 case 'b':
418 addr = optarg;
419 break;
420 case 'd':
421 debug = 1;
422 break;
423 case 't':
424 timeout.tv_sec = strtonum(optarg, 0, INT_MAX, &errstr);
425 if (errstr != NULL)
426 fatalx("timeout is %s: %s", errstr, optarg);
427 break;
428 case 'v':
429 verbose = 1;
430 break;
431 default:
432 usage();
435 argc -= optind;
436 argv += optind;
438 if (argc != 1 || addr == NULL || ssh_tflag == NULL)
439 usage();
441 ssh_dest = argv[0];
443 for (i = 0; i < MAXCONN; ++i) {
444 conns[i].source = -1;
445 conns[i].to = -1;
448 log_init(debug, LOG_DAEMON);
449 log_setverbose(verbose);
451 if (!debug)
452 daemon(1, 0);
454 bind_socket();
456 signal(SIGPIPE, SIG_IGN);
458 event_init();
460 /* initialize the timer */
461 evtimer_set(&timeoutev, killing_time, NULL);
463 signal_set(&sighupev, SIGHUP, terminate, NULL);
464 signal_set(&sigintev, SIGINT, terminate, NULL);
465 signal_set(&sigtermev, SIGTERM, terminate, NULL);
466 signal_set(&sigchldev, SIGCHLD, chld, NULL);
467 signal_set(&siginfoev, SIGINFO, info, NULL);
469 signal_add(&sighupev, NULL);
470 signal_add(&sigintev, NULL);
471 signal_add(&sigtermev, NULL);
472 signal_add(&sigchldev, NULL);
473 signal_add(&siginfoev, NULL);
475 for (i = 0; i < nsock; ++i) {
476 event_set(&sockev[i], socks[i], EV_READ|EV_PERSIST,
477 do_accept, NULL);
478 event_add(&sockev[i], NULL);
481 /*
482 * dns, inet: bind the socket and connect to the childs.
483 * proc, exec: execute ssh on demand.
484 */
485 if (pledge("stdio dns inet proc exec", NULL) == -1)
486 fatal("pledge");
488 log_info("starting");
489 event_dispatch();
491 if (ssh_pid != -1)
492 kill(ssh_pid, SIGINT);
494 return 0;