Blob


/*
 * Copyright (c) 2021 Omar Polo <op@omarpolo.com>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */
17 #include <sys/types.h>
18 #include <sys/socket.h>
19 #include <sys/wait.h>
21 #include <ctype.h>
22 #include <err.h>
23 #include <errno.h>
24 #include <event.h>
25 #include <limits.h>
26 #include <netdb.h>
27 #include <signal.h>
28 #include <stdio.h>
29 #include <stdlib.h>
30 #include <string.h>
31 #include <unistd.h>
33 #include <stdint.h>
/* ssh binary to execute; may be overridden at build time with -DSSH_PATH. */
#ifndef SSH_PATH
#define SSH_PATH "/usr/bin/ssh"
#endif

/*
 * Tunables.  They are plain integer constants, so an enum keeps them
 * visible to the debugger and scoped like ordinary identifiers.
 */
enum {
	MAXSOCK = 4,	/* max listening sockets bound from -b */
	MAXCONN = 16,	/* max simultaneously proxied client connections */
	BACKOFF = 1,	/* seconds between attempts to reach ssh's forward */
	RETRIES = 8,	/* attempts to reach ssh before dropping a client */
};

/* pledge(2) only exists on OpenBSD; elsewhere pretend it succeeded. */
#ifndef __OpenBSD__
#define pledge(p, e) 0
#endif
48 int rport; /* ssh port */
49 const char *addr; /* our addr */
50 const char *ssh_tflag;
51 const char *ssh_dest;
53 char ssh_host[256];
54 char ssh_port[16];
56 struct event sockev[MAXSOCK];
57 int socks[MAXSOCK];
58 int nsock;
60 struct event sighupev;
61 struct event sigintev;
62 struct event sigtermev;
63 struct event sigchldev;
64 struct event siginfoev;
66 struct timeval timeout;
67 struct event timeoutev;
69 pid_t ssh_pid = -1;
71 int conn;
73 struct conn {
74 int ntentative;
75 struct timeval retry;
76 struct event waitev;
77 int source;
78 struct bufferevent *sourcebev;
79 int to;
80 struct bufferevent *tobev;
81 } conns[MAXCONN];
/*
 * Callback for SIGHUP, SIGINT and SIGTERM: break out of the event loop
 * so main() can stop ssh and exit.
 */
static void
terminate(int fd, short event, void *data)
{
	(void)fd;
	(void)event;
	(void)data;

	event_loopbreak();
}
89 static void
90 chld(int fd, short event, void *data)
91 {
92 int status;
93 pid_t pid;
95 if ((pid = waitpid(ssh_pid, &status, WNOHANG)) == -1)
96 err(1, "waitpid");
98 ssh_pid = -1;
99 }
101 static void
102 info(int fd, short event, void *data)
104 warnx("connections: %d", conn);
107 static void
108 spawn_ssh(void)
110 warnx("spawning ssh...");
112 switch (ssh_pid = fork()) {
113 case -1:
114 err(1, "fork");
115 case 0:
116 execl(SSH_PATH, "ssh", "-L", ssh_tflag, "-NTq", ssh_dest,
117 NULL);
118 err(1, "exec");
119 default:
120 return;
124 static void
125 killing_time(int fd, short event, void *data)
127 if (ssh_pid == -1)
128 return;
130 warnx("killing time!");
131 kill(ssh_pid, SIGTERM);
132 ssh_pid = -1;
/* Bufferevent write callback that deliberately does nothing. */
static void
nopcb(struct bufferevent *bev, void *d)
{
	(void)bev;
	(void)d;
}
141 static void
142 sreadcb(struct bufferevent *bev, void *d)
144 struct conn *c = d;
146 bufferevent_write_buffer(c->tobev, EVBUFFER_INPUT(bev));
149 static void
150 treadcb(struct bufferevent *bev, void *d)
152 struct conn *c = d;
154 bufferevent_write_buffer(c->sourcebev, EVBUFFER_INPUT(bev));
157 static void
158 errcb(struct bufferevent *bev, short event, void *d)
160 struct conn *c = d;
162 warnx("in errcb, closing connection");
164 bufferevent_free(c->sourcebev);
165 bufferevent_free(c->tobev);
167 close(c->source);
168 close(c->to);
170 c->source = -1;
171 c->to = -1;
173 if (--conn == 0) {
174 warnx("scheduling ssh termination (%llds)",
175 (long long)timeout.tv_sec);
176 evtimer_set(&timeoutev, killing_time, NULL);
177 evtimer_add(&timeoutev, &timeout);
/*
 * Open a TCP connection to the local end of ssh's forward
 * (ssh_host:ssh_port), trying each resolved address in turn.
 *
 * Returns the connected socket, or -1 after logging the failure.  The
 * caller (try_to_connect) is expected to retry.
 */
static int
connect_to_ssh(void)
{
	struct addrinfo hints, *res, *res0;
	int r, saved_errno, sock = -1;
	const char *cause = "connect";

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;

	r = getaddrinfo(ssh_host, ssh_port, &hints, &res0);
	if (r != 0) {
		/*
		 * Don't exit: a resolver hiccup would otherwise tear down
		 * every established connection.  Returning -1 lets the
		 * caller retry and, at worst, give up on this client only.
		 */
		warnx("getaddrinfo(\"%s\", \"%s\"): %s",
		    ssh_host, ssh_port, gai_strerror(r));
		return -1;
	}

	for (res = res0; res != NULL; res = res->ai_next) {
		sock = socket(res->ai_family, res->ai_socktype,
		    res->ai_protocol);
		if (sock == -1) {
			cause = "socket";
			continue;
		}

		if (connect(sock, res->ai_addr, res->ai_addrlen) == -1) {
			cause = "connect";
			/* close(2) may clobber errno; keep it for warn() */
			saved_errno = errno;
			close(sock);
			errno = saved_errno;
			sock = -1;
			continue;
		}

		break;
	}

	/* warn(), not warnx(): report the errno preserved above */
	if (sock == -1)
		warn("%s", cause);

	freeaddrinfo(res0);
	return sock;
}
224 static void
225 try_to_connect(int fd, short event, void *d)
227 struct conn *c = d;
229 /* ssh may die in the meantime */
230 if (ssh_pid == -1) {
231 close(c->source);
232 c->source = -1;
233 return;
236 c->ntentative++;
237 warnx("trying to connect to %s:%s (%d/%d)", ssh_host, ssh_port,
238 c->ntentative, RETRIES);
240 if ((c->to = connect_to_ssh()) == -1) {
241 if (c->ntentative == RETRIES) {
242 warnx("giving up");
243 close(c->source);
244 c->source = -1;
245 return;
248 evtimer_set(&c->waitev, try_to_connect, c);
249 evtimer_add(&c->waitev, &c->retry);
250 return;
253 c->sourcebev = bufferevent_new(c->source, sreadcb, nopcb, errcb, c);
254 c->tobev = bufferevent_new(c->to, treadcb, nopcb, errcb, c);
255 if (c->sourcebev == NULL || c->tobev == NULL)
256 err(1, "bufferevent_new");
257 bufferevent_enable(c->sourcebev, EV_READ|EV_WRITE);
258 bufferevent_enable(c->tobev, EV_READ|EV_WRITE);
261 static void
262 do_accept(int fd, short event, void *data)
264 int s, i;
266 warnx("handling connection");
268 if (evtimer_pending(&timeoutev, NULL))
269 evtimer_del(&timeoutev);
271 if ((s = accept(fd, NULL, 0)) == -1)
272 err(1, "accept");
274 if (conn == MAXCONN) {
275 /* oops */
276 close(s);
277 return;
280 conn++;
282 if (ssh_pid == -1)
283 spawn_ssh();
285 for (i = 0; i < MAXCONN; ++i) {
286 if (conns[i].source != -1)
287 continue;
289 conns[i].source = s;
290 conns[i].ntentative = 0;
291 conns[i].retry.tv_sec = BACKOFF;
292 conns[i].retry.tv_usec = 0;
293 evtimer_set(&conns[i].waitev, try_to_connect, &conns[i]);
294 evtimer_add(&conns[i].waitev, &conns[i].retry);
295 break;
299 static void __dead
300 usage(void)
302 fprintf(stderr, "usage: %s -B sshaddr -b addr [-t timeout]"
303 " destination\n", getprogname());
304 exit(1);
/*
 * Copy the part of `s` before its first ':' into `d`, a buffer of `len`
 * bytes.  `d` is zeroed first, so the copy is always NUL-terminated.
 *
 * Returns a pointer to that ':' inside `s`.  Returns NULL, leaving `d`
 * untouched, if `s` has no ':' or the field plus its NUL does not fit
 * in `len` bytes.
 */
static const char *
copysec(const char *s, char *d, size_t len)
{
	const char *c;
	size_t n;

	if ((c = strchr(s, ':')) == NULL)
		return NULL;

	n = (size_t)(c - s);
	/*
	 * Need room for n bytes plus the NUL, i.e. n < len.  Written this
	 * way it also rejects len == 0, where `len - 1` would wrap around.
	 */
	if (n >= len)
		return NULL;

	memset(d, 0, len);
	memcpy(d, s, n);
	return c;
}
/*
 * Resolve `addr` ("port" or "host:port") and create a listening TCP
 * socket for each returned address, up to MAXSOCK, storing them in
 * socks[] / nsock.  Exits if no socket could be set up.
 */
static void
bind_socket(void)
{
	struct addrinfo hints, *res, *res0;
	int r, saved_errno, on = 1;
	char host[64];
	const char *c, *h, *port, *cause = "getaddrinfo";

	if ((c = strchr(addr, ':')) == NULL) {
		h = NULL;
		port = addr;
	} else {
		if ((c = copysec(addr, host, sizeof(host))) == NULL)
			errx(1, "ENAMETOOLONG");

		h = host;
		port = c+1;
	}

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_flags = AI_PASSIVE;

	r = getaddrinfo(h, port, &hints, &res0);
	if (r != 0)
		errx(1, "getaddrinfo(%s): %s",
		    addr, gai_strerror(r));

	for (res = res0; res && nsock < MAXSOCK; res = res->ai_next) {
		socks[nsock] = socket(res->ai_family, res->ai_socktype,
		    res->ai_protocol);
		if (socks[nsock] == -1) {
			cause = "socket";
			continue;
		}

		/* allow a quick restart while old connections sit in TIME_WAIT */
		if (setsockopt(socks[nsock], SOL_SOCKET, SO_REUSEADDR,
		    &on, sizeof(on)) == -1)
			warn("setsockopt(SO_REUSEADDR)");

		if (bind(socks[nsock], res->ai_addr, res->ai_addrlen) == -1) {
			cause = "bind";
		} else if (listen(socks[nsock], 5) == -1) {
			/* a socket that is not listening must not be kept */
			cause = "listen";
		} else {
			nsock++;
			continue;
		}

		/* close(2) may clobber errno; keep it for err() below */
		saved_errno = errno;
		close(socks[nsock]);
		errno = saved_errno;
	}

	if (nsock == 0)
		err(1, "%s", cause);

	freeaddrinfo(res0);
}
376 static void
377 parse_tflag(void)
379 const char *c;
381 if (isdigit(*ssh_tflag)) {
382 strlcpy(ssh_host, "localhost", sizeof(ssh_host));
383 if (copysec(ssh_tflag, ssh_port, sizeof(ssh_port)) == NULL)
384 goto err;
385 return;
388 if ((c = copysec(ssh_tflag, ssh_host, sizeof(ssh_host))) == NULL)
389 goto err;
390 if (copysec(c+1, ssh_port, sizeof(ssh_port)) == NULL)
391 goto err;
392 return;
394 err:
395 errx(1, "wrong value for -B");
398 int
399 main(int argc, char **argv)
401 int ch, tout, i;
402 const char *errstr;
404 while ((ch = getopt(argc, argv, "B:b:t:")) != -1) {
405 switch (ch) {
406 case 'B':
407 ssh_tflag = optarg;
408 parse_tflag();
409 break;
410 case 'b':
411 addr = optarg;
412 break;
413 case 't':
414 tout = strtonum(optarg, 1, INT_MAX, &errstr);
415 if (errstr != NULL)
416 errx(1, "timeout is %s: %s", errstr, optarg);
417 break;
418 default:
419 usage();
422 argc -= optind;
423 argv += optind;
425 if (argc != 1 || addr == NULL || ssh_tflag == NULL)
426 usage();
428 if (tout == 0)
429 tout = 120;
431 timeout.tv_sec = tout;
432 timeout.tv_usec = 0;
434 ssh_dest = argv[0];
436 for (i = 0; i < MAXCONN; ++i) {
437 conns[i].source = -1;
438 conns[i].to = -1;
441 bind_socket();
443 signal(SIGPIPE, SIG_IGN);
445 event_init();
447 /* initialize the timer */
448 evtimer_set(&timeoutev, killing_time, NULL);
450 signal_set(&sighupev, SIGHUP, terminate, NULL);
451 signal_set(&sigintev, SIGINT, terminate, NULL);
452 signal_set(&sigtermev, SIGTERM, terminate, NULL);
453 signal_set(&sigchldev, SIGCHLD, chld, NULL);
454 signal_set(&siginfoev, SIGINFO, info, NULL);
456 signal_add(&sighupev, NULL);
457 signal_add(&sigintev, NULL);
458 signal_add(&sigtermev, NULL);
459 signal_add(&sigchldev, NULL);
460 signal_add(&siginfoev, NULL);
462 for (i = 0; i < nsock; ++i) {
463 event_set(&sockev[i], socks[i], EV_READ|EV_PERSIST,
464 do_accept, NULL);
465 event_add(&sockev[i], NULL);
468 /*
469 * dns, inet: bind the socket and connect to the childs.
470 * proc, exec: execute ssh on demand.
471 */
472 if (pledge("stdio dns inet proc exec", NULL) == -1)
473 err(1, "pledge");
475 warnx("lift off!");
476 event_dispatch();
478 if (ssh_pid != -1)
479 kill(ssh_pid, SIGINT);
481 return 0;