Blob


1 /*
2 * Copyright (c) 2021 Omar Polo <op@omarpolo.com>
3 *
4 * Permission to use, copy, modify, and distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
7 *
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15 */
17 #include "compat.h"
19 #include <sys/types.h>
20 #include <sys/socket.h>
21 #include <sys/wait.h>
23 #include <ctype.h>
24 #include <errno.h>
25 #include <event.h>
26 #include <limits.h>
27 #include <netdb.h>
28 #include <signal.h>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32 #include <unistd.h>
/* ssh(1) binary to run; can be overridden at build time with -DSSH_PATH=... */
#ifndef SSH_PATH
#define SSH_PATH "ssh"
#endif

/* Tunables. */
enum {
	MAXSOCK = 4,	/* max number of listening sockets */
	MAXCONN = 16,	/* max number of client connections handled at once */
	BACKOFF = 1,	/* seconds between attempts to connect to ssh */
	RETRIES = 8,	/* attempts to connect to ssh before giving up */
};

/* pledge(2) is OpenBSD-only: elsewhere it becomes a no-op that succeeds. */
#ifndef __OpenBSD__
#define pledge(p, e) 0
#endif
47 const char *addr; /* our addr */
48 const char *ssh_tflag;
49 const char *ssh_dest;
51 char ssh_host[256];
52 char ssh_port[16];
54 struct event sockev[MAXSOCK];
55 int socks[MAXSOCK];
56 int nsock;
58 struct event sighupev;
59 struct event sigintev;
60 struct event sigtermev;
61 struct event sigchldev;
62 struct event siginfoev;
64 struct timeval timeout = {120, 0};
65 struct event timeoutev;
67 pid_t ssh_pid = -1;
69 int conn;
71 struct conn {
72 int ntentative;
73 struct timeval retry;
74 struct event waitev;
75 int source;
76 struct bufferevent *sourcebev;
77 int to;
78 struct bufferevent *tobev;
79 } conns[MAXCONN];
/*
 * Handler for SIGHUP, SIGINT and SIGTERM: break out of the event loop so
 * that main() regains control after event_dispatch().
 */
static void
terminate(int fd, short event, void *data)
{
	(void)fd;
	(void)event;
	(void)data;

	event_loopbreak();
}
87 static void
88 chld(int fd, short event, void *data)
89 {
90 int status;
92 if (waitpid(ssh_pid, &status, WNOHANG) == -1)
93 err(1, "waitpid");
95 ssh_pid = -1;
96 }
98 static void
99 info(int fd, short event, void *data)
101 warnx("connections: %d", conn);
104 static void
105 spawn_ssh(void)
107 warnx("spawning ssh...");
109 switch (ssh_pid = fork()) {
110 case -1:
111 err(1, "fork");
112 case 0:
113 execl(SSH_PATH, "ssh", "-L", ssh_tflag, "-NTq", ssh_dest,
114 NULL);
115 err(1, "exec");
116 default:
117 return;
121 static void
122 killing_time(int fd, short event, void *data)
124 if (ssh_pid == -1)
125 return;
127 warnx("killing time!");
128 kill(ssh_pid, SIGTERM);
129 ssh_pid = -1;
/* Bufferevent callback that does nothing; used as the write callback. */
static void
nopcb(struct bufferevent *bev, void *d)
{
	(void)bev;
	(void)d;
}
138 static void
139 sreadcb(struct bufferevent *bev, void *d)
141 struct conn *c = d;
143 bufferevent_write_buffer(c->tobev, EVBUFFER_INPUT(bev));
146 static void
147 treadcb(struct bufferevent *bev, void *d)
149 struct conn *c = d;
151 bufferevent_write_buffer(c->sourcebev, EVBUFFER_INPUT(bev));
154 static void
155 errcb(struct bufferevent *bev, short event, void *d)
157 struct conn *c = d;
159 warnx("in errcb, closing connection");
161 bufferevent_free(c->sourcebev);
162 bufferevent_free(c->tobev);
164 close(c->source);
165 close(c->to);
167 c->source = -1;
168 c->to = -1;
170 if (--conn == 0) {
171 warnx("scheduling ssh termination (%llds)",
172 (long long)timeout.tv_sec);
173 if (timeout.tv_sec != 0) {
174 evtimer_set(&timeoutev, killing_time, NULL);
175 evtimer_add(&timeoutev, &timeout);
/*
 * Open a TCP connection to ssh_host:ssh_port, trying every address that
 * getaddrinfo(3) returns in order.  Returns the connected socket, or -1
 * (after logging a warning) on failure.
 *
 * A name-resolution failure is reported and returned as -1 rather than
 * exiting, so the caller's retry logic applies to it too and a transient
 * resolver error doesn't bring the whole proxy down.
 */
static int
connect_to_ssh(void)
{
	struct addrinfo	 hints, *res, *res0;
	int		 r, saved_errno, sock = -1;
	const char	*cause = "connect";

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;

	r = getaddrinfo(ssh_host, ssh_port, &hints, &res0);
	if (r != 0) {
		warnx("getaddrinfo(\"%s\", \"%s\"): %s",
		    ssh_host, ssh_port, gai_strerror(r));
		return -1;
	}

	for (res = res0; res != NULL; res = res->ai_next) {
		sock = socket(res->ai_family, res->ai_socktype,
		    res->ai_protocol);
		if (sock == -1) {
			cause = "socket";
			continue;
		}

		if (connect(sock, res->ai_addr, res->ai_addrlen) == -1) {
			cause = "connect";
			saved_errno = errno;
			close(sock);
			errno = saved_errno;	/* keep it for warn() below */
			sock = -1;
			continue;
		}

		break;
	}

	if (sock == -1)
		warn("%s", cause);

	freeaddrinfo(res0);
	return sock;
}
223 static void
224 try_to_connect(int fd, short event, void *d)
226 struct conn *c = d;
228 /* ssh may die in the meantime */
229 if (ssh_pid == -1) {
230 close(c->source);
231 c->source = -1;
232 return;
235 c->ntentative++;
236 warnx("trying to connect to %s:%s (%d/%d)", ssh_host, ssh_port,
237 c->ntentative, RETRIES);
239 if ((c->to = connect_to_ssh()) == -1) {
240 if (c->ntentative == RETRIES) {
241 warnx("giving up");
242 close(c->source);
243 c->source = -1;
244 return;
247 evtimer_set(&c->waitev, try_to_connect, c);
248 evtimer_add(&c->waitev, &c->retry);
249 return;
252 c->sourcebev = bufferevent_new(c->source, sreadcb, nopcb, errcb, c);
253 c->tobev = bufferevent_new(c->to, treadcb, nopcb, errcb, c);
254 if (c->sourcebev == NULL || c->tobev == NULL)
255 err(1, "bufferevent_new");
256 bufferevent_enable(c->sourcebev, EV_READ|EV_WRITE);
257 bufferevent_enable(c->tobev, EV_READ|EV_WRITE);
260 static void
261 do_accept(int fd, short event, void *data)
263 int s, i;
265 warnx("handling connection");
267 if (evtimer_pending(&timeoutev, NULL))
268 evtimer_del(&timeoutev);
270 if ((s = accept(fd, NULL, 0)) == -1)
271 err(1, "accept");
273 if (conn == MAXCONN) {
274 /* oops */
275 close(s);
276 return;
279 conn++;
281 if (ssh_pid == -1)
282 spawn_ssh();
284 for (i = 0; i < MAXCONN; ++i) {
285 if (conns[i].source != -1)
286 continue;
288 conns[i].source = s;
289 conns[i].ntentative = 0;
290 conns[i].retry.tv_sec = BACKOFF;
291 conns[i].retry.tv_usec = 0;
292 evtimer_set(&conns[i].waitev, try_to_connect, &conns[i]);
293 evtimer_add(&conns[i].waitev, &conns[i].retry);
294 break;
/*
 * Copy the part of `s' before its first ':' into `d', a buffer of `len'
 * bytes, and NUL-terminate it.  Returns a pointer to that ':' inside `s',
 * or NULL if `s' has no ':' or the section (plus its NUL) doesn't fit.
 */
static const char *
copysec(const char *s, char *d, size_t len)
{
	const char	*c;
	size_t		 n;

	if ((c = strchr(s, ':')) == NULL)
		return NULL;
	n = (size_t)(c - s);
	/* n bytes + NUL must fit; this also rejects len == 0 safely */
	if (n >= len)
		return NULL;
	memcpy(d, s, n);
	d[n] = '\0';
	return c;
}
/*
 * Create, bind and listen on up to MAXSOCK sockets for `addr' (the -b
 * argument), which is either "port" (wildcard address) or "host:port".
 * The port is taken after the last ':' so that a bracketed IPv6 literal
 * such as "[::1]:8080" works; the brackets are stripped.  The sockets are
 * stored in socks[] and counted in nsock.  Exits on failure.
 */
static void
bind_socket(void)
{
	struct addrinfo	 hints, *res, *res0;
	int		 r, saved_errno, on = 1;
	size_t		 hlen;
	char		 host[64];
	const char	*c, *h, *port, *cause = "bind";

	if ((c = strrchr(addr, ':')) == NULL) {
		h = NULL;
		port = addr;
	} else {
		h = addr;
		hlen = (size_t)(c - addr);
		if (hlen >= 2 && h[0] == '[' && h[hlen - 1] == ']') {
			h++;
			hlen -= 2;
		}
		if (hlen >= sizeof(host))
			errx(1, "ENAMETOOLONG");
		memcpy(host, h, hlen);
		host[hlen] = '\0';
		h = host;
		port = c + 1;
	}

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_flags = AI_PASSIVE;

	r = getaddrinfo(h, port, &hints, &res0);
	if (r != 0)
		errx(1, "getaddrinfo(%s): %s",
		    addr, gai_strerror(r));

	for (res = res0; res != NULL && nsock < MAXSOCK; res = res->ai_next) {
		socks[nsock] = socket(res->ai_family, res->ai_socktype,
		    res->ai_protocol);
		if (socks[nsock] == -1) {
			cause = "socket";
			continue;
		}

		/* allow restarting while old connections sit in TIME_WAIT */
		if (setsockopt(socks[nsock], SOL_SOCKET, SO_REUSEADDR, &on,
		    sizeof(on)) == -1)
			err(1, "setsockopt(SO_REUSEADDR)");

		if (bind(socks[nsock], res->ai_addr, res->ai_addrlen) == -1) {
			cause = "bind";
			saved_errno = errno;
			close(socks[nsock]);
			errno = saved_errno;	/* keep it for err() below */
			continue;
		}

		if (listen(socks[nsock], 5) == -1)
			err(1, "listen");

		nsock++;
	}
	if (nsock == 0)
		err(1, "%s", cause);

	freeaddrinfo(res0);
}
367 static void
368 parse_tflag(void)
370 const char *c;
372 if (isdigit(*ssh_tflag)) {
373 strlcpy(ssh_host, "localhost", sizeof(ssh_host));
374 if (copysec(ssh_tflag, ssh_port, sizeof(ssh_port)) == NULL)
375 goto err;
376 return;
379 if ((c = copysec(ssh_tflag, ssh_host, sizeof(ssh_host))) == NULL)
380 goto err;
381 if (copysec(c+1, ssh_port, sizeof(ssh_port)) == NULL)
382 goto err;
383 return;
385 err:
386 errx(1, "wrong value for -B");
389 static void __dead
390 usage(void)
392 fprintf(stderr, "usage: %s -B sshaddr -b addr [-t timeout]"
393 " destination\n", getprogname());
394 exit(1);
397 int
398 main(int argc, char **argv)
400 int ch, i;
401 const char *errstr;
403 while ((ch = getopt(argc, argv, "B:b:t:")) != -1) {
404 switch (ch) {
405 case 'B':
406 ssh_tflag = optarg;
407 parse_tflag();
408 break;
409 case 'b':
410 addr = optarg;
411 break;
412 case 't':
413 timeout.tv_sec = strtonum(optarg, 0, INT_MAX, &errstr);
414 if (errstr != NULL)
415 errx(1, "timeout is %s: %s", errstr, optarg);
416 break;
417 default:
418 usage();
421 argc -= optind;
422 argv += optind;
424 if (argc != 1 || addr == NULL || ssh_tflag == NULL)
425 usage();
427 ssh_dest = argv[0];
429 for (i = 0; i < MAXCONN; ++i) {
430 conns[i].source = -1;
431 conns[i].to = -1;
434 bind_socket();
436 signal(SIGPIPE, SIG_IGN);
438 event_init();
440 /* initialize the timer */
441 evtimer_set(&timeoutev, killing_time, NULL);
443 signal_set(&sighupev, SIGHUP, terminate, NULL);
444 signal_set(&sigintev, SIGINT, terminate, NULL);
445 signal_set(&sigtermev, SIGTERM, terminate, NULL);
446 signal_set(&sigchldev, SIGCHLD, chld, NULL);
447 signal_set(&siginfoev, SIGINFO, info, NULL);
449 signal_add(&sighupev, NULL);
450 signal_add(&sigintev, NULL);
451 signal_add(&sigtermev, NULL);
452 signal_add(&sigchldev, NULL);
453 signal_add(&siginfoev, NULL);
455 for (i = 0; i < nsock; ++i) {
456 event_set(&sockev[i], socks[i], EV_READ|EV_PERSIST,
457 do_accept, NULL);
458 event_add(&sockev[i], NULL);
461 /*
462 * dns, inet: bind the socket and connect to the childs.
463 * proc, exec: execute ssh on demand.
464 */
465 if (pledge("stdio dns inet proc exec", NULL) == -1)
466 err(1, "pledge");
468 warnx("lift off!");
469 event_dispatch();
471 if (ssh_pid != -1)
472 kill(ssh_pid, SIGINT);
474 return 0;