Blob


1 #include <u.h>
2 #include <libc.h>
3 #include <fcall.h>
4 #include <thread.h>
/* Tunables. */
enum
{
	STACK	= 32768,	/* stack size passed to threadcreate */
	NHASH	= 31,		/* buckets per hash table (keys are taken mod NHASH) */
	MAXMSG	= 64,		/* outstanding requests allowed per connection */
};
/* Forward declarations so the structures below can refer to each other. */
typedef struct Conn Conn;
typedef struct Fid Fid;
typedef struct Hash Hash;
typedef struct Msg Msg;
typedef struct Queue Queue;
19 struct Hash
20 {
21 Hash *next;
22 uint n;
23 void *v;
24 };
26 struct Fid
27 {
28 int fid;
29 int ref;
30 int cfid;
31 Fid *next;
32 };
34 struct Msg
35 {
36 Conn *c;
37 int internal;
38 int ref;
39 int ctag;
40 int tag;
41 Fcall tx;
42 Fcall rx;
43 Fid *fid;
44 Fid *newfid;
45 Fid *afid;
46 Msg *oldm;
47 Msg *next;
48 uchar *tpkt;
49 uchar *rpkt;
50 };
52 struct Conn
53 {
54 int fd;
55 int nmsg;
56 int nfid;
57 Channel *inc;
58 Channel *internal;
59 int inputstalled;
60 char dir[40];
61 Hash *tag[NHASH];
62 Hash *fid[NHASH];
63 Queue *outq;
64 Queue *inq;
65 };
67 char *addr;
68 int afd;
69 char adir[40];
70 int isunix;
71 Queue *outq;
72 Queue *inq;
74 void *gethash(Hash**, uint);
75 int puthash(Hash**, uint, void*);
76 int delhash(Hash**, uint, void*);
77 Msg *mread9p(int);
78 int mwrite9p(int, Msg*);
79 uchar *read9ppkt(int);
80 int write9ppkt(int, uchar*);
81 Msg *msgnew(void);
82 void msgput(Msg*);
83 Msg *msgget(int);
84 Fid *fidnew(int);
85 void fidput(Fid*);
86 void *emalloc(int);
87 void *erealloc(void*, int);
88 int sendq(Queue*, void*);
89 void *recvq(Queue*);
90 void selectthread(void*);
91 void connthread(void*);
92 void listenthread(void*);
93 void rewritehdr(Fcall*, uchar*);
94 int tlisten(char*, char*);
95 int taccept(int, char*);
/*
 * Print the command synopsis on standard error and exit with
 * status "usage".  Does not return.
 */
void
usage(void)
{
	static char *msg[] = {
		"usage: 9pserve [-u] address\n",
		"\treads/writes 9P messages on stdin/stdout\n",
	};
	int i;

	for(i = 0; i < nelem(msg); i++)
		fprint(2, "%s", msg[i]);
	exits("usage");
}
105 void
106 threadmain(int argc, char **argv)
108 ARGBEGIN{
109 default:
110 usage();
111 case 'u':
112 isunix = 1;
113 break;
114 }ARGEND
116 if(argc != 1)
117 usage();
119 if((afd = announce(addr, adir)) < 0)
120 sysfatal("announce %s: %r", addr);
122 threadcreateidle(selectthread, nil, STACK);
125 void
126 listenthread(void *arg)
128 Conn *c;
130 USED(arg);
131 for(;;){
132 c = malloc(sizeof(Conn));
133 if(c == nil){
134 fprint(2, "out of memory\n");
135 sleep(60*1000);
136 continue;
138 c->fd = tlisten(adir, c->dir);
139 if(c->fd < 0){
140 fprint(2, "listen: %r\n");
141 close(afd);
142 free(c);
143 return;
145 threadcreate(connthread, c, STACK);
149 void
150 err(Msg *m, char *ename)
152 int n, nn;
154 m->rx.type = Rerror;
155 m->rx.ename = ename;
156 m->rx.tag = m->ctag;
157 n = sizeS2M(&m->rx);
158 m->rpkt = emalloc(n);
159 nn = convS2M(&m->rx, m->rpkt, n);
160 if(nn != n)
161 sysfatal("sizeS2M + convS2M disagree");
162 sendq(m->c->outq, m);
165 void
166 connthread(void *arg)
168 int i, fd;
169 Conn *c;
170 Hash *h;
171 Msg *m, *om;
172 Fid *f;
174 c = arg;
175 fd = taccept(c->fd, c->dir);
176 if(fd < 0){
177 fprint(2, "accept %s: %r\n", c->dir);
178 goto out;
180 close(c->fd);
181 c->fd = fd;
182 while((m = mread9p(c->fd)) != nil){
183 m->c = c;
184 c->nmsg++;
185 if(puthash(c->tag, m->tx.tag, m) < 0){
186 err(m, "duplicate tag");
187 continue;
189 switch(m->tx.type){
190 case Tflush:
191 if((m->oldm = gethash(c->tag, m->tx.oldtag)) == nil){
192 m->rx.tag = Rflush;
193 sendq(c->outq, m);
194 continue;
196 break;
197 case Tattach:
198 m->fid = fidnew(m->tx.fid);
199 if(puthash(c->fid, m->tx.fid, m->fid) < 0){
200 err(m, "duplicate fid");
201 continue;
203 m->fid->ref++;
204 break;
205 case Twalk:
206 if((m->fid = gethash(c->fid, m->tx.fid)) == nil){
207 err(m, "unknown fid");
208 continue;
210 if(m->tx.newfid == m->tx.fid){
211 m->fid->ref++;
212 m->newfid = m->fid;
213 }else{
214 m->newfid = fidnew(m->tx.newfid);
215 if(puthash(c->fid, m->tx.newfid, m->newfid) < 0){
216 err(m, "duplicate fid");
217 continue;
219 m->newfid->ref++;
221 break;
222 case Tauth:
223 if((m->afid = gethash(c->fid, m->tx.afid)) == nil){
224 err(m, "unknown fid");
225 continue;
227 m->fid = fidnew(m->tx.fid);
228 if(puthash(c->fid, m->tx.fid, m->fid) < 0){
229 err(m, "duplicate fid");
230 continue;
232 m->fid->ref++;
233 break;
234 case Topen:
235 case Tclunk:
236 case Tread:
237 case Twrite:
238 case Tstat:
239 case Twstat:
240 if((m->fid = gethash(c->fid, m->tx.fid)) == nil){
241 err(m, "unknown fid");
242 continue;
244 m->fid->ref++;
245 break;
248 /* have everything - translate and send */
249 m->c = c;
250 m->ctag = m->tx.tag;
251 m->tx.tag = m->tag;
252 if(m->fid)
253 m->tx.fid = m->fid->fid;
254 if(m->newfid)
255 m->tx.newfid = m->newfid->fid;
256 if(m->afid)
257 m->tx.afid = m->afid->fid;
258 if(m->oldm)
259 m->tx.oldtag = m->oldm->tag;
260 rewritehdr(&m->tx, m->tpkt);
261 sendq(outq, m);
262 while(c->nmsg >= MAXMSG){
263 c->inputstalled = 1;
264 recvp(c->inc);
268 /* flush all outstanding messages */
269 for(i=0; i<NHASH; i++){
270 for(h=c->tag[i]; h; h=h->next){
271 om = h->v;
272 m = msgnew();
273 m->internal = 1;
274 m->c = c;
275 m->tx.type = Tflush;
276 m->tx.tag = m->tag;
277 m->tx.oldtag = om->tag;
278 m->oldm = om;
279 om->ref++;
280 sendq(outq, m);
281 recvp(c->internal);
285 /* clunk all outstanding fids */
286 for(i=0; i<NHASH; i++){
287 for(h=c->fid[i]; h; h=h->next){
288 f = h->v;
289 m = msgnew();
290 m->internal = 1;
291 m->c = c;
292 m->tx.type = Tclunk;
293 m->tx.tag = m->tag;
294 m->tx.fid = f->fid;
295 m->fid = f;
296 f->ref++;
297 sendq(outq, m);
298 recvp(c->internal);
302 out:
303 assert(c->nmsg == 0);
304 assert(c->nfid == 0);
305 close(c->fd);
306 free(c);
309 void
310 connoutthread(void *arg)
312 int err;
313 Conn *c;
314 Msg *m, *om;
316 c = arg;
317 while((m = recvq(c->outq)) != nil){
318 err = m->tx.type+1 != m->rx.type;
319 switch(m->tx.type){
320 case Tflush:
321 om = m->oldm;
322 if(delhash(om->c->tag, om->ctag, om) == 0)
323 msgput(om);
324 break;
325 case Tclunk:
326 if(delhash(m->c->fid, m->fid->cfid, m->fid) == 0)
327 fidput(m->fid);
328 break;
329 case Tauth:
330 if(err)
331 if(delhash(m->c->fid, m->afid->cfid, m->fid) == 0)
332 fidput(m->fid);
333 case Tattach:
334 if(err)
335 if(delhash(m->c->fid, m->fid->cfid, m->fid) == 0)
336 fidput(m->fid);
337 break;
338 case Twalk:
339 if(err && m->tx.fid != m->tx.newfid)
340 if(delhash(m->c->fid, m->newfid->cfid, m->newfid) == 0)
341 fidput(m->newfid);
342 break;
344 if(mwrite9p(c->fd, m) < 0)
345 fprint(2, "write error: %r\n");
346 if(delhash(m->c->tag, m->ctag, m) == 0)
347 msgput(m);
348 msgput(m);
349 if(c->inputstalled && c->nmsg < MAXMSG)
350 nbsendp(c->inc, 0);
354 void
355 outputthread(void *arg)
357 Msg *m;
359 USED(arg);
361 while((m = recvq(outq)) != nil){
362 if(mwrite9p(1, m) < 0)
363 sysfatal("output error: %r");
364 msgput(m);
368 void
369 inputthread(void *arg)
371 uchar *pkt;
372 int n, nn, tag;
373 Msg *m;
375 while((pkt = read9ppkt(0)) != nil){
376 n = GBIT32(pkt);
377 if(n < 7){
378 fprint(2, "short 9P packet\n");
379 free(pkt);
380 continue;
382 tag = GBIT16(pkt+5);
383 if((m = msgget(tag)) == nil){
384 fprint(2, "unexpected 9P response tag %d\n", tag);
385 free(pkt);
386 msgput(m);
387 continue;
389 if((nn = convM2S(pkt, n, &m->rx)) != n){
390 fprint(2, "bad packet - convM2S %d but %d\n", nn, n);
391 free(pkt);
392 msgput(m);
393 continue;
395 m->rpkt = pkt;
396 m->rx.tag = m->ctag;
397 rewritehdr(&m->rx, m->rpkt);
398 sendq(m->c->outq, m);
402 void*
403 gethash(Hash **ht, uint n)
405 Hash *h;
407 for(h=ht[n%NHASH]; h; h=h->next)
408 if(h->n == n)
409 return h->v;
410 return nil;
413 int
414 delhash(Hash **ht, uint n, void *v)
416 Hash *h, **l;
418 for(l=&ht[n%NHASH]; h=*l; l=&h->next)
419 if(h->n == n){
420 if(h->v != v)
421 fprint(2, "hash error\n");
422 *l = h->next;
423 free(h);
424 return 0;
426 return -1;
429 int
430 puthash(Hash **ht, uint n, void *v)
432 Hash *h;
434 if(gethash(ht, n))
435 return -1;
436 h = emalloc(sizeof(Hash));
437 h->next = ht[n%NHASH];
438 h->n = n;
439 h->v = v;
440 ht[n%NHASH] = h;
441 return 0;
444 Fid **fidtab;
445 int nfidtab;
446 Fid *freefid;
448 Fid*
449 fidnew(int cfid)
451 Fid *f;
453 if(freefid == nil){
454 fidtab = erealloc(fidtab, nfidtab*sizeof(fidtab[0]));
455 fidtab[nfidtab] = emalloc(sizeof(Fid));
456 freefid = fidtab[nfidtab++];
458 f = freefid;
459 freefid = f->next;
460 f->cfid = f->cfid;
461 f->ref = 1;
462 return f;
465 void
466 fidput(Fid *f)
468 assert(f->ref > 0);
469 if(--f->ref > 0)
470 return;
471 f->next = freefid;
472 f->cfid = -1;
473 freefid = f;
476 Msg **msgtab;
477 int nmsgtab;
478 Msg *freemsg;
480 Msg*
481 msgnew(void)
483 Msg *m;
485 if(freemsg == nil){
486 msgtab = erealloc(msgtab, nmsgtab*sizeof(msgtab[0]));
487 msgtab[nmsgtab] = emalloc(sizeof(Msg));
488 freemsg = msgtab[nmsgtab++];
490 m = freemsg;
491 freemsg = m->next;
492 m->ref = 1;
493 return m;
496 void
497 msgput(Msg *m)
499 assert(m->ref > 0);
500 if(--m->ref > 0)
501 return;
502 m->next = freemsg;
503 freemsg = m;
506 void*
507 emalloc(int n)
509 void *v;
511 v = mallocz(n, 1);
512 if(v == nil)
513 sysfatal("out of memory");
514 return v;
517 void*
518 erealloc(void *v, int n)
520 v = realloc(v, n);
521 if(v == nil)
522 sysfatal("out of memory");
523 return v;
/* A queued element: singly linked list node carrying an opaque pointer. */
typedef struct Qel Qel;
struct Qel
{
	Qel	*next;	/* next element toward the tail */
	void	*p;	/* payload handed to sendq */
};
533 struct Queue
535 int hungup;
536 QLock lk;
537 Rendez r;
538 Qel *head;
539 Qel *tail;
540 };
542 Queue*
543 qalloc(void)
545 Queue *q;
547 q = mallocz(sizeof(Queue), 1);
548 if(q == nil)
549 return nil;
550 q->r.l = &q->lk;
551 return q;
554 int
555 sendq(Queue *q, void *p)
557 Qel *e;
559 e = emalloc(sizeof(Qel));
560 qlock(&q->lk);
561 if(q->hungup){
562 werrstr("hungup queue");
563 qunlock(&q->lk);
564 return -1;
566 e->p = p;
567 e->next = nil;
568 if(q->head == nil)
569 q->head = e;
570 else
571 q->tail->next = e;
572 q->tail = e;
573 rwakeup(&q->r);
574 qunlock(&q->lk);
575 return 0;
578 void*
579 recvq(Queue *q)
581 void *p;
582 Qel *e;
584 qlock(&q->lk);
585 while(q->head == nil && !q->hungup)
586 rsleep(&q->r);
587 if(q->hungup){
588 qunlock(&q->lk);
589 return nil;
591 e = q->head;
592 q->head = e->next;
593 qunlock(&q->lk);
594 p = e->p;
595 free(e);
596 return p;