Blob


1 /*
2 * Copyright (c) 2022 Stefan Sperling <stsp@openbsd.org>
3 *
4 * Permission to use, copy, modify, and distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
7 *
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15 */
17 #include <sys/types.h>
18 #include <sys/queue.h>
19 #include <sys/socket.h>
20 #include <sys/uio.h>
22 #include <errno.h>
23 #include <event.h>
24 #include <siphash.h>
25 #include <stdint.h>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29 #include <imsg.h>
30 #include <limits.h>
31 #include <sha1.h>
32 #include <signal.h>
33 #include <unistd.h>
35 #include "got_error.h"
37 #include "gotd.h"
38 #include "log.h"
39 #include "listen.h"
41 #ifndef nitems
42 #define nitems(_a) (sizeof((_a)) / sizeof((_a)[0]))
43 #endif
45 struct gotd_listen_client {
46 STAILQ_ENTRY(gotd_listen_client) entry;
47 uint32_t id;
48 int fd;
49 uid_t euid;
50 };
51 STAILQ_HEAD(gotd_listen_clients, gotd_listen_client);
53 static struct gotd_listen_clients gotd_listen_clients[GOTD_CLIENT_TABLE_SIZE];
54 static SIPHASH_KEY clients_hash_key;
55 static volatile int listen_client_cnt;
56 static int inflight;
58 struct gotd_uid_connection_counter {
59 STAILQ_ENTRY(gotd_uid_connection_counter) entry;
60 uid_t euid;
61 int nconnections;
62 };
63 STAILQ_HEAD(gotd_client_uids, gotd_uid_connection_counter);
64 static struct gotd_client_uids gotd_client_uids[GOTD_CLIENT_TABLE_SIZE];
65 static SIPHASH_KEY uid_hash_key;
67 static struct {
68 pid_t pid;
69 const char *title;
70 int fd;
71 struct gotd_imsgev iev;
72 struct gotd_imsgev pause;
73 } gotd_listen;
75 static int inflight;
77 static void listen_shutdown(void);
/*
 * libevent callback for SIGHUP, SIGUSR1, SIGTERM and SIGINT.
 *
 * libevent delivers the signal from its event loop rather than from an
 * asynchronous handler, so ordinary functions may be called here.
 * SIGHUP and SIGUSR1 are accepted and ignored.  SIGTERM and SIGINT shut
 * the process down via listen_shutdown(), which exits and does not return.
 * Any other signal is treated as fatal.
 */
static void
listen_sighdlr(int sig, short event, void *arg)
{
	if (sig == SIGTERM || sig == SIGINT) {
		listen_shutdown();
		/* NOTREACHED */
		return;
	}

	if (sig != SIGHUP && sig != SIGUSR1)
		fatalx("unexpected signal");
}
102 static uint64_t
103 client_hash(uint32_t client_id)
105 return SipHash24(&clients_hash_key, &client_id, sizeof(client_id));
108 static void
109 add_client(struct gotd_listen_client *client)
111 uint64_t slot = client_hash(client->id) % nitems(gotd_listen_clients);
112 STAILQ_INSERT_HEAD(&gotd_listen_clients[slot], client, entry);
113 listen_client_cnt++;
116 static struct gotd_listen_client *
117 find_client(uint32_t client_id)
119 uint64_t slot;
120 struct gotd_listen_client *c;
122 slot = client_hash(client_id) % nitems(gotd_listen_clients);
123 STAILQ_FOREACH(c, &gotd_listen_clients[slot], entry) {
124 if (c->id == client_id)
125 return c;
128 return NULL;
/*
 * Return a random client id that is nonzero and not used by any client
 * currently in the table.  Draws again until such an id is found.
 */
static uint32_t
get_client_id(void)
{
	uint32_t candidate;

	for (;;) {
		candidate = arc4random();
		if (candidate != 0 && find_client(candidate) == NULL)
			return candidate;
	}
}
145 static uint64_t
146 uid_hash(uid_t euid)
148 return SipHash24(&uid_hash_key, &euid, sizeof(euid));
151 static void
152 add_uid_connection_counter(struct gotd_uid_connection_counter *counter)
154 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
155 STAILQ_INSERT_HEAD(&gotd_client_uids[slot], counter, entry);
158 static void
159 remove_uid_connection_counter(struct gotd_uid_connection_counter *counter)
161 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
162 STAILQ_REMOVE(&gotd_client_uids[slot], counter,
163 gotd_uid_connection_counter, entry);
166 static struct gotd_uid_connection_counter *
167 find_uid_connection_counter(uid_t euid)
169 uint64_t slot;
170 struct gotd_uid_connection_counter *c;
172 slot = uid_hash(euid) % nitems(gotd_client_uids);
173 STAILQ_FOREACH(c, &gotd_client_uids[slot], entry) {
174 if (c->euid == euid)
175 return c;
178 return NULL;
181 static const struct got_error *
182 disconnect(struct gotd_listen_client *client)
184 struct gotd_uid_connection_counter *counter;
185 uint64_t slot;
186 int client_fd;
188 log_debug("client on fd %d disconnecting", client->fd);
190 slot = client_hash(client->id) % nitems(gotd_listen_clients);
191 STAILQ_REMOVE(&gotd_listen_clients[slot], client,
192 gotd_listen_client, entry);
194 counter = find_uid_connection_counter(client->euid);
195 if (counter) {
196 if (counter->nconnections > 0)
197 counter->nconnections--;
198 if (counter->nconnections == 0) {
199 remove_uid_connection_counter(counter);
200 free(counter);
204 client_fd = client->fd;
205 free(client);
206 inflight--;
207 listen_client_cnt--;
208 if (close(client_fd) == -1)
209 return got_error_from_errno("close");
211 return NULL;
214 static int
215 accept_reserve(int fd, struct sockaddr *addr, socklen_t *addrlen,
216 int reserve, volatile int *counter)
218 int ret;
220 if (getdtablecount() + reserve +
221 ((*counter + 1) * GOTD_FD_NEEDED) >= getdtablesize()) {
222 log_debug("inflight fds exceeded");
223 errno = EMFILE;
224 return -1;
227 if ((ret = accept4(fd, addr, addrlen,
228 SOCK_NONBLOCK | SOCK_CLOEXEC)) > -1) {
229 (*counter)++;
232 return ret;
235 static void
236 gotd_accept_paused(int fd, short event, void *arg)
238 event_add(&gotd_listen.iev.ev, NULL);
241 static void
242 gotd_accept(int fd, short event, void *arg)
244 struct gotd_imsgev *iev = arg;
245 struct sockaddr_storage ss;
246 struct timeval backoff;
247 socklen_t len;
248 int s = -1;
249 struct gotd_listen_client *client = NULL;
250 struct gotd_uid_connection_counter *counter = NULL;
251 struct gotd_imsg_connect iconn;
252 uid_t euid;
253 gid_t egid;
255 backoff.tv_sec = 1;
256 backoff.tv_usec = 0;
258 if (event_add(&gotd_listen.iev.ev, NULL) == -1) {
259 log_warn("event_add");
260 return;
262 if (event & EV_TIMEOUT)
263 return;
265 len = sizeof(ss);
267 /* Other backoff conditions apart from EMFILE/ENFILE? */
268 s = accept_reserve(fd, (struct sockaddr *)&ss, &len, GOTD_FD_RESERVE,
269 &inflight);
270 if (s == -1) {
271 switch (errno) {
272 case EINTR:
273 case EWOULDBLOCK:
274 case ECONNABORTED:
275 return;
276 case EMFILE:
277 case ENFILE:
278 event_del(&gotd_listen.iev.ev);
279 evtimer_add(&gotd_listen.pause.ev, &backoff);
280 return;
281 default:
282 log_warn("accept");
283 return;
287 if (listen_client_cnt >= GOTD_MAXCLIENTS)
288 goto err;
290 if (getpeereid(s, &euid, &egid) == -1) {
291 log_warn("getpeerid");
292 goto err;
295 counter = find_uid_connection_counter(euid);
296 if (counter == NULL) {
297 counter = calloc(1, sizeof(*counter));
298 if (counter == NULL) {
299 log_warn("%s: calloc", __func__);
300 goto err;
302 counter->euid = euid;
303 counter->nconnections = 1;
304 add_uid_connection_counter(counter);
305 } else {
306 if (counter->nconnections >= GOTD_MAX_CONN_PER_UID) {
307 log_warnx("maximum connections exceeded for uid %d",
308 euid);
309 goto err;
311 counter->nconnections++;
314 client = calloc(1, sizeof(*client));
315 if (client == NULL) {
316 log_warn("%s: calloc", __func__);
317 goto err;
319 client->id = get_client_id();
320 client->fd = s;
321 client->euid = euid;
322 s = -1;
323 add_client(client);
324 log_debug("%s: new client connected on fd %d uid %d gid %d", __func__,
325 client->fd, euid, egid);
327 memset(&iconn, 0, sizeof(iconn));
328 iconn.client_id = client->id;
329 iconn.euid = euid;
330 iconn.egid = egid;
331 s = dup(client->fd);
332 if (s == -1) {
333 log_warn("%s: dup", __func__);
334 goto err;
336 if (gotd_imsg_compose_event(iev, GOTD_IMSG_CONNECT, PROC_LISTEN, s,
337 &iconn, sizeof(iconn)) == -1) {
338 log_warn("imsg compose CONNECT");
339 goto err;
342 return;
343 err:
344 inflight--;
345 if (client)
346 disconnect(client);
347 if (s != -1)
348 close(s);
351 static const struct got_error *
352 recv_disconnect(struct imsg *imsg)
354 struct gotd_imsg_disconnect idisconnect;
355 size_t datalen;
356 struct gotd_listen_client *client = NULL;
358 datalen = imsg->hdr.len - IMSG_HEADER_SIZE;
359 if (datalen != sizeof(idisconnect))
360 return got_error(GOT_ERR_PRIVSEP_LEN);
361 memcpy(&idisconnect, imsg->data, sizeof(idisconnect));
363 log_debug("client disconnecting");
365 client = find_client(idisconnect.client_id);
366 if (client == NULL)
367 return got_error(GOT_ERR_CLIENT_ID);
369 return disconnect(client);
372 static void
373 listen_dispatch(int fd, short event, void *arg)
375 const struct got_error *err = NULL;
376 struct gotd_imsgev *iev = arg;
377 struct imsgbuf *ibuf = &iev->ibuf;
378 struct imsg imsg;
379 ssize_t n;
380 int shut = 0;
382 if (event & EV_READ) {
383 if ((n = imsg_read(ibuf)) == -1 && errno != EAGAIN)
384 fatal("imsg_read error");
385 if (n == 0) /* Connection closed. */
386 shut = 1;
389 if (event & EV_WRITE) {
390 n = msgbuf_write(&ibuf->w);
391 if (n == -1 && errno != EAGAIN)
392 fatal("msgbuf_write");
393 if (n == 0) /* Connection closed. */
394 shut = 1;
397 for (;;) {
398 if ((n = imsg_get(ibuf, &imsg)) == -1)
399 fatal("%s: imsg_get", __func__);
400 if (n == 0) /* No more messages. */
401 break;
403 switch (imsg.hdr.type) {
404 case GOTD_IMSG_DISCONNECT:
405 err = recv_disconnect(&imsg);
406 if (err)
407 log_warnx("%s: disconnect: %s",
408 gotd_listen.title, err->msg);
409 break;
410 default:
411 log_debug("%s: unexpected imsg %d", gotd_listen.title,
412 imsg.hdr.type);
413 break;
416 imsg_free(&imsg);
419 if (!shut) {
420 gotd_imsg_event_add(iev);
421 } else {
422 /* This pipe is dead. Remove its event handler */
423 event_del(&iev->ev);
424 event_loopexit(NULL);
428 void
429 listen_main(const char *title, int gotd_socket)
431 struct gotd_imsgev iev;
432 struct event evsigint, evsigterm, evsighup, evsigusr1;
434 arc4random_buf(&clients_hash_key, sizeof(clients_hash_key));
435 arc4random_buf(&uid_hash_key, sizeof(uid_hash_key));
437 gotd_listen.title = title;
438 gotd_listen.pid = getpid();
439 gotd_listen.fd = gotd_socket;
441 signal_set(&evsigint, SIGINT, listen_sighdlr, NULL);
442 signal_set(&evsigterm, SIGTERM, listen_sighdlr, NULL);
443 signal_set(&evsighup, SIGHUP, listen_sighdlr, NULL);
444 signal_set(&evsigusr1, SIGUSR1, listen_sighdlr, NULL);
445 signal(SIGPIPE, SIG_IGN);
447 signal_add(&evsigint, NULL);
448 signal_add(&evsigterm, NULL);
449 signal_add(&evsighup, NULL);
450 signal_add(&evsigusr1, NULL);
452 imsg_init(&iev.ibuf, GOTD_FILENO_MSG_PIPE);
453 iev.handler = listen_dispatch;
454 iev.events = EV_READ;
455 iev.handler_arg = NULL;
456 event_set(&iev.ev, iev.ibuf.fd, EV_READ, listen_dispatch, &iev);
457 if (event_add(&iev.ev, NULL) == -1)
458 fatalx("event add");
460 event_set(&gotd_listen.iev.ev, gotd_listen.fd, EV_READ | EV_PERSIST,
461 gotd_accept, &iev);
462 if (event_add(&gotd_listen.iev.ev, NULL))
463 fatalx("event add");
464 evtimer_set(&gotd_listen.pause.ev, gotd_accept_paused, NULL);
466 event_dispatch();
468 listen_shutdown();
471 static void
472 listen_shutdown(void)
474 log_debug("%s: shutting down", gotd_listen.title);
476 if (gotd_listen.fd != -1)
477 close(gotd_listen.fd);
479 exit(0);