Blob


1 /*
2 * Copyright (c) 2022 Stefan Sperling <stsp@openbsd.org>
3 *
4 * Permission to use, copy, modify, and distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
7 *
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15 */
17 #include <sys/types.h>
18 #include <sys/queue.h>
19 #include <sys/socket.h>
20 #include <sys/uio.h>
22 #include <errno.h>
23 #include <event.h>
24 #include <siphash.h>
25 #include <stdint.h>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29 #include <imsg.h>
30 #include <limits.h>
31 #include <sha1.h>
32 #include <sha2.h>
33 #include <signal.h>
34 #include <unistd.h>
36 #include "got_error.h"
38 #include "gotd.h"
39 #include "log.h"
40 #include "listen.h"
/*
 * Number of elements in an array object. Only meaningful for true arrays;
 * a pointer argument yields a wrong result.
 */
#ifndef nitems
#define nitems(_a) (sizeof((_a)) / sizeof((_a)[0]))
#endif
/* A connected client as tracked by the listen process. */
struct gotd_listen_client {
	STAILQ_ENTRY(gotd_listen_client) entry;	/* hash bucket linkage */
	uint32_t id;	/* non-zero ID chosen by get_client_id() */
	int fd;		/* accepted socket; closed by disconnect() */
	uid_t euid;	/* peer's effective uid from getpeereid() */
};
STAILQ_HEAD(gotd_listen_clients, gotd_listen_client);
54 static struct gotd_listen_clients gotd_listen_clients[GOTD_CLIENT_TABLE_SIZE];
55 static SIPHASH_KEY clients_hash_key;
56 static volatile int listen_client_cnt;
57 static int inflight;
59 struct gotd_uid_connection_counter {
60 STAILQ_ENTRY(gotd_uid_connection_counter) entry;
61 uid_t euid;
62 int nconnections;
63 };
64 STAILQ_HEAD(gotd_client_uids, gotd_uid_connection_counter);
65 static struct gotd_client_uids gotd_client_uids[GOTD_CLIENT_TABLE_SIZE];
66 static SIPHASH_KEY uid_hash_key;
68 static struct {
69 pid_t pid;
70 const char *title;
71 int fd;
72 struct gotd_imsgev iev;
73 struct gotd_imsgev pause;
74 struct gotd_uid_connection_limit *connection_limits;
75 size_t nconnection_limits;
76 } gotd_listen;
/* Forward declaration; listen_shutdown() exits the process. */
static void listen_shutdown(void);
/*
 * libevent signal callback for the listen process.
 *
 * SIGHUP and SIGUSR1 are accepted and ignored. SIGTERM and SIGINT shut
 * the process down via listen_shutdown(). Any other signal is fatal.
 * The event and arg parameters are unused.
 */
static void
listen_sighdlr(int sig, short event, void *arg)
{
	/*
	 * Normal signal handler rules don't apply because libevent
	 * decouples for us.
	 */

	switch (sig) {
	case SIGHUP:
		break;
	case SIGUSR1:
		break;
	case SIGTERM:
	case SIGINT:
		listen_shutdown();
		/* NOTREACHED */
		break;
	default:
		fatalx("unexpected signal");
	}
}
105 static uint64_t
106 client_hash(uint32_t client_id)
108 return SipHash24(&clients_hash_key, &client_id, sizeof(client_id));
111 static void
112 add_client(struct gotd_listen_client *client)
114 uint64_t slot = client_hash(client->id) % nitems(gotd_listen_clients);
115 STAILQ_INSERT_HEAD(&gotd_listen_clients[slot], client, entry);
116 listen_client_cnt++;
119 static struct gotd_listen_client *
120 find_client(uint32_t client_id)
122 uint64_t slot;
123 struct gotd_listen_client *c;
125 slot = client_hash(client_id) % nitems(gotd_listen_clients);
126 STAILQ_FOREACH(c, &gotd_listen_clients[slot], entry) {
127 if (c->id == client_id)
128 return c;
131 return NULL;
/*
 * Pick a random client ID which is non-zero and not used by any client
 * currently present in the client table.
 */
static uint32_t
get_client_id(void)
{
	int duplicate = 0;
	uint32_t id;

	do {
		id = arc4random();
		duplicate = (find_client(id) != NULL);
	} while (duplicate || id == 0);

	return id;
}
148 static uint64_t
149 uid_hash(uid_t euid)
151 return SipHash24(&uid_hash_key, &euid, sizeof(euid));
154 static void
155 add_uid_connection_counter(struct gotd_uid_connection_counter *counter)
157 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
158 STAILQ_INSERT_HEAD(&gotd_client_uids[slot], counter, entry);
161 static void
162 remove_uid_connection_counter(struct gotd_uid_connection_counter *counter)
164 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
165 STAILQ_REMOVE(&gotd_client_uids[slot], counter,
166 gotd_uid_connection_counter, entry);
169 static struct gotd_uid_connection_counter *
170 find_uid_connection_counter(uid_t euid)
172 uint64_t slot;
173 struct gotd_uid_connection_counter *c;
175 slot = uid_hash(euid) % nitems(gotd_client_uids);
176 STAILQ_FOREACH(c, &gotd_client_uids[slot], entry) {
177 if (c->euid == euid)
178 return c;
181 return NULL;
184 struct gotd_uid_connection_limit *
185 gotd_find_uid_connection_limit(struct gotd_uid_connection_limit *limits,
186 size_t nlimits, uid_t uid)
188 /* This array is always sorted to allow for binary search. */
189 int i, left = 0, right = nlimits - 1;
191 while (left <= right) {
192 i = ((left + right) / 2);
193 if (limits[i].uid == uid)
194 return &limits[i];
195 if (limits[i].uid > uid)
196 left = i + 1;
197 else
198 right = i - 1;
201 return NULL;
204 static const struct got_error *
205 disconnect(struct gotd_listen_client *client)
207 struct gotd_uid_connection_counter *counter;
208 uint64_t slot;
209 int client_fd;
211 log_debug("client on fd %d disconnecting", client->fd);
213 slot = client_hash(client->id) % nitems(gotd_listen_clients);
214 STAILQ_REMOVE(&gotd_listen_clients[slot], client,
215 gotd_listen_client, entry);
217 counter = find_uid_connection_counter(client->euid);
218 if (counter) {
219 if (counter->nconnections > 0)
220 counter->nconnections--;
221 if (counter->nconnections == 0) {
222 remove_uid_connection_counter(counter);
223 free(counter);
227 client_fd = client->fd;
228 free(client);
229 inflight--;
230 listen_client_cnt--;
231 if (close(client_fd) == -1)
232 return got_error_from_errno("close");
234 return NULL;
237 static int
238 accept_reserve(int fd, struct sockaddr *addr, socklen_t *addrlen,
239 int reserve, volatile int *counter)
241 int ret;
243 if (getdtablecount() + reserve +
244 ((*counter + 1) * GOTD_FD_NEEDED) >= getdtablesize()) {
245 log_debug("inflight fds exceeded");
246 errno = EMFILE;
247 return -1;
250 if ((ret = accept4(fd, addr, addrlen,
251 SOCK_NONBLOCK | SOCK_CLOEXEC)) > -1) {
252 (*counter)++;
255 return ret;
258 static void
259 gotd_accept_paused(int fd, short event, void *arg)
261 event_add(&gotd_listen.iev.ev, NULL);
264 static void
265 gotd_accept(int fd, short event, void *arg)
267 struct gotd_imsgev *iev = arg;
268 struct sockaddr_storage ss;
269 struct timeval backoff;
270 socklen_t len;
271 int s = -1;
272 struct gotd_listen_client *client = NULL;
273 struct gotd_uid_connection_counter *counter = NULL;
274 struct gotd_imsg_connect iconn;
275 uid_t euid;
276 gid_t egid;
278 backoff.tv_sec = 1;
279 backoff.tv_usec = 0;
281 if (event_add(&gotd_listen.iev.ev, NULL) == -1) {
282 log_warn("event_add");
283 return;
285 if (event & EV_TIMEOUT)
286 return;
288 len = sizeof(ss);
290 /* Other backoff conditions apart from EMFILE/ENFILE? */
291 s = accept_reserve(fd, (struct sockaddr *)&ss, &len, GOTD_FD_RESERVE,
292 &inflight);
293 if (s == -1) {
294 switch (errno) {
295 case EINTR:
296 case EWOULDBLOCK:
297 case ECONNABORTED:
298 return;
299 case EMFILE:
300 case ENFILE:
301 event_del(&gotd_listen.iev.ev);
302 evtimer_add(&gotd_listen.pause.ev, &backoff);
303 return;
304 default:
305 log_warn("accept");
306 return;
310 if (listen_client_cnt >= GOTD_MAXCLIENTS)
311 goto err;
313 if (getpeereid(s, &euid, &egid) == -1) {
314 log_warn("getpeerid");
315 goto err;
318 counter = find_uid_connection_counter(euid);
319 if (counter == NULL) {
320 counter = calloc(1, sizeof(*counter));
321 if (counter == NULL) {
322 log_warn("%s: calloc", __func__);
323 goto err;
325 counter->euid = euid;
326 counter->nconnections = 1;
327 add_uid_connection_counter(counter);
328 } else {
329 int max_connections = GOTD_MAX_CONN_PER_UID;
330 struct gotd_uid_connection_limit *limit;
332 limit = gotd_find_uid_connection_limit(
333 gotd_listen.connection_limits,
334 gotd_listen.nconnection_limits, euid);
335 if (limit)
336 max_connections = limit->max_connections;
338 if (counter->nconnections >= max_connections) {
339 log_warnx("maximum connections exceeded for uid %d",
340 euid);
341 goto err;
343 counter->nconnections++;
346 client = calloc(1, sizeof(*client));
347 if (client == NULL) {
348 log_warn("%s: calloc", __func__);
349 goto err;
351 client->id = get_client_id();
352 client->fd = s;
353 client->euid = euid;
354 s = -1;
355 add_client(client);
356 log_debug("%s: new client connected on fd %d uid %d gid %d", __func__,
357 client->fd, euid, egid);
359 memset(&iconn, 0, sizeof(iconn));
360 iconn.client_id = client->id;
361 iconn.euid = euid;
362 iconn.egid = egid;
363 s = dup(client->fd);
364 if (s == -1) {
365 log_warn("%s: dup", __func__);
366 goto err;
368 if (gotd_imsg_compose_event(iev, GOTD_IMSG_CONNECT, PROC_LISTEN, s,
369 &iconn, sizeof(iconn)) == -1) {
370 log_warn("imsg compose CONNECT");
371 goto err;
374 return;
375 err:
376 inflight--;
377 if (client)
378 disconnect(client);
379 if (s != -1)
380 close(s);
383 static const struct got_error *
384 recv_disconnect(struct imsg *imsg)
386 struct gotd_imsg_disconnect idisconnect;
387 size_t datalen;
388 struct gotd_listen_client *client = NULL;
390 datalen = imsg->hdr.len - IMSG_HEADER_SIZE;
391 if (datalen != sizeof(idisconnect))
392 return got_error(GOT_ERR_PRIVSEP_LEN);
393 memcpy(&idisconnect, imsg->data, sizeof(idisconnect));
395 log_debug("client disconnecting");
397 client = find_client(idisconnect.client_id);
398 if (client == NULL)
399 return got_error(GOT_ERR_CLIENT_ID);
401 return disconnect(client);
404 static void
405 listen_dispatch(int fd, short event, void *arg)
407 const struct got_error *err = NULL;
408 struct gotd_imsgev *iev = arg;
409 struct imsgbuf *ibuf = &iev->ibuf;
410 struct imsg imsg;
411 ssize_t n;
412 int shut = 0;
414 if (event & EV_READ) {
415 if ((n = imsg_read(ibuf)) == -1 && errno != EAGAIN)
416 fatal("imsg_read error");
417 if (n == 0) /* Connection closed. */
418 shut = 1;
421 if (event & EV_WRITE) {
422 n = msgbuf_write(&ibuf->w);
423 if (n == -1 && errno != EAGAIN)
424 fatal("msgbuf_write");
425 if (n == 0) /* Connection closed. */
426 shut = 1;
429 for (;;) {
430 if ((n = imsg_get(ibuf, &imsg)) == -1)
431 fatal("%s: imsg_get", __func__);
432 if (n == 0) /* No more messages. */
433 break;
435 switch (imsg.hdr.type) {
436 case GOTD_IMSG_DISCONNECT:
437 err = recv_disconnect(&imsg);
438 if (err)
439 log_warnx("%s: disconnect: %s",
440 gotd_listen.title, err->msg);
441 break;
442 default:
443 log_debug("%s: unexpected imsg %d", gotd_listen.title,
444 imsg.hdr.type);
445 break;
448 imsg_free(&imsg);
451 if (!shut) {
452 gotd_imsg_event_add(iev);
453 } else {
454 /* This pipe is dead. Remove its event handler */
455 event_del(&iev->ev);
456 event_loopexit(NULL);
460 void
461 listen_main(const char *title, int gotd_socket,
462 struct gotd_uid_connection_limit *connection_limits,
463 size_t nconnection_limits)
465 struct gotd_imsgev iev;
466 struct event evsigint, evsigterm, evsighup, evsigusr1;
468 arc4random_buf(&clients_hash_key, sizeof(clients_hash_key));
469 arc4random_buf(&uid_hash_key, sizeof(uid_hash_key));
471 gotd_listen.title = title;
472 gotd_listen.pid = getpid();
473 gotd_listen.fd = gotd_socket;
474 gotd_listen.connection_limits = connection_limits;
475 gotd_listen.nconnection_limits = nconnection_limits;
477 signal_set(&evsigint, SIGINT, listen_sighdlr, NULL);
478 signal_set(&evsigterm, SIGTERM, listen_sighdlr, NULL);
479 signal_set(&evsighup, SIGHUP, listen_sighdlr, NULL);
480 signal_set(&evsigusr1, SIGUSR1, listen_sighdlr, NULL);
481 signal(SIGPIPE, SIG_IGN);
483 signal_add(&evsigint, NULL);
484 signal_add(&evsigterm, NULL);
485 signal_add(&evsighup, NULL);
486 signal_add(&evsigusr1, NULL);
488 imsg_init(&iev.ibuf, GOTD_FILENO_MSG_PIPE);
489 iev.handler = listen_dispatch;
490 iev.events = EV_READ;
491 iev.handler_arg = NULL;
492 event_set(&iev.ev, iev.ibuf.fd, EV_READ, listen_dispatch, &iev);
493 if (event_add(&iev.ev, NULL) == -1)
494 fatalx("event add");
496 event_set(&gotd_listen.iev.ev, gotd_listen.fd, EV_READ | EV_PERSIST,
497 gotd_accept, &iev);
498 if (event_add(&gotd_listen.iev.ev, NULL))
499 fatalx("event add");
500 evtimer_set(&gotd_listen.pause.ev, gotd_accept_paused, NULL);
502 event_dispatch();
504 listen_shutdown();
507 static void
508 listen_shutdown(void)
510 log_debug("%s: shutting down", gotd_listen.title);
512 free(gotd_listen.connection_limits);
513 if (gotd_listen.fd != -1)
514 close(gotd_listen.fd);
516 exit(0);