Blob


1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
/* The main groups of functions are: */
/*		client/server - main handshake protocol definition */
/*		message functions - encoding and decoding (formatting) handshake messages */
/*		cipher choices - catalog of digest and encrypt algorithms */
/*		security functions - PKCS#1, sslHMAC, session keygen */
/*		general utility functions - malloc, serialization */
/* The handshake protocol builds on the TLS/SSL3 record layer protocol, */
/* which is implemented in kernel device #a.  See also /lib/rfc/rfc2246. */
17 enum {
18 TLSFinishedLen = 12,
19 SSL3FinishedLen = MD5dlen+SHA1dlen,
20 MaxKeyData = 104, /* amount of secret we may need */
21 MaxChunk = 1<<14,
22 RandomSize = 32,
23 SidSize = 32,
24 MasterSecretSize = 48,
25 AQueue = 0,
26 AFlush = 1
27 };
29 typedef struct TlsSec TlsSec;
31 typedef struct Bytes{
32 int len;
33 uchar data[1]; /* [len] */
34 } Bytes;
36 typedef struct Ints{
37 int len;
38 int data[1]; /* [len] */
39 } Ints;
41 typedef struct Algs{
42 char *enc;
43 char *digest;
44 int nsecret;
45 int tlsid;
46 int ok;
47 } Algs;
49 typedef struct Finished{
50 uchar verify[SSL3FinishedLen];
51 int n;
52 } Finished;
54 typedef struct TlsConnection{
55 TlsSec *sec; /* security management goo */
56 int hand, ctl; /* record layer file descriptors */
57 int erred; /* set when tlsError called */
58 int (*trace)(char*fmt, ...); /* for debugging */
59 int version; /* protocol we are speaking */
60 int verset; /* version has been set */
61 int ver2hi; /* server got a version 2 hello */
62 int isClient; /* is this the client or server? */
63 Bytes *sid; /* SessionID */
64 Bytes *cert; /* only last - no chain */
66 Lock statelk;
67 int state; /* must be set using setstate */
69 /* input buffer for handshake messages */
70 uchar buf[MaxChunk+2048];
71 uchar *rp, *ep;
73 uchar crandom[RandomSize]; /* client random */
74 uchar srandom[RandomSize]; /* server random */
75 int clientVersion; /* version in ClientHello */
76 char *digest; /* name of digest algorithm to use */
77 char *enc; /* name of encryption algorithm to use */
78 int nsecret; /* amount of secret data to init keys */
80 /* for finished messages */
81 MD5state hsmd5; /* handshake hash */
82 SHAstate hssha1; /* handshake hash */
83 Finished finished;
84 } TlsConnection;
86 typedef struct Msg{
87 int tag;
88 union {
89 struct {
90 int version;
91 uchar random[RandomSize];
92 Bytes* sid;
93 Ints* ciphers;
94 Bytes* compressors;
95 } clientHello;
96 struct {
97 int version;
98 uchar random[RandomSize];
99 Bytes* sid;
100 int cipher;
101 int compressor;
102 } serverHello;
103 struct {
104 int ncert;
105 Bytes **certs;
106 } certificate;
107 struct {
108 Bytes *types;
109 int nca;
110 Bytes **cas;
111 } certificateRequest;
112 struct {
113 Bytes *key;
114 } clientKeyExchange;
115 Finished finished;
116 } u;
117 } Msg;
119 struct TlsSec{
120 char *server; /* name of remote; nil for server */
121 int ok; /* <0 killed; ==0 in progress; >0 reusable */
122 RSApub *rsapub;
123 AuthRpc *rpc; /* factotum for rsa private key */
124 uchar sec[MasterSecretSize]; /* master secret */
125 uchar crandom[RandomSize]; /* client random */
126 uchar srandom[RandomSize]; /* server random */
127 int clientVers; /* version in ClientHello */
128 int vers; /* final version */
129 /* byte generation and handshake checksum */
130 void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
131 void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
132 int nfin;
133 };
/* Protocol version numbers (major<<8 | minor). */
enum {
	TLSVersion = 0x0301,
	SSL3Version = 0x0300,
	ProtocolVersion = 0x0301,	/* maximum version we speak */
	MinProtoVersion = 0x0300,	/* limits on version we accept */
	MaxProtoVersion	= 0x03ff
};
/* handshake type: the first byte of every handshake message (see msgSend/msgRecv) */
enum {
	HHelloRequest = 0,
	HClientHello = 1,
	HServerHello = 2,
	HSSL2ClientHello = 9,  /* local convention;  see devtls.c */
	HCertificate = 11,
	HServerKeyExchange = 12,
	HCertificateRequest = 13,
	HServerHelloDone = 14,
	HCertificateVerify = 15,
	HClientKeyExchange = 16,
	HFinished = 20,
	HMax
};
/* alerts: description codes passed to tlsError */
enum {
	ECloseNotify = 0,
	EUnexpectedMessage = 10,
	EBadRecordMac = 20,
	EDecryptionFailed = 21,
	ERecordOverflow = 22,
	EDecompressionFailure = 30,
	EHandshakeFailure = 40,
	ENoCertificate = 41,
	EBadCertificate = 42,
	EUnsupportedCertificate = 43,
	ECertificateRevoked = 44,
	ECertificateExpired = 45,
	ECertificateUnknown = 46,
	EIllegalParameter = 47,
	EUnknownCa = 48,
	EAccessDenied = 49,
	EDecodeError = 50,
	EDecryptError = 51,
	EExportRestriction = 60,
	EProtocolVersion = 70,
	EInsufficientSecurity = 71,
	EInternalError = 80,
	EUserCanceled = 90,
	ENoRenegotiation = 100,
	EMax = 256
};
/* cipher suites: two-byte identifiers as sent in hello messages */
enum {
	TLS_NULL_WITH_NULL_NULL	 		= 0x0000,
	TLS_RSA_WITH_NULL_MD5 			= 0x0001,
	TLS_RSA_WITH_NULL_SHA 			= 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5 		= 0x0003,
	TLS_RSA_WITH_RC4_128_MD5 		= 0x0004,
	TLS_RSA_WITH_RC4_128_SHA 		= 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5	= 0x0006,
	TLS_RSA_WITH_IDEA_CBC_SHA 		= 0x0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0008,
	TLS_RSA_WITH_DES_CBC_SHA		= 0x0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0x000A,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x000B,
	TLS_DH_DSS_WITH_DES_CBC_SHA		= 0x000C,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA	= 0x000D,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x000E,
	TLS_DH_RSA_WITH_DES_CBC_SHA		= 0x000F,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA		= 0x0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA	= 0x0013,	/* ZZZ must be implemented for tls1.0 compliance */
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA		= 0x0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5	= 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5 		= 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA	= 0x0019,
	TLS_DH_anon_WITH_DES_CBC_SHA		= 0x001A,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA	= 0x001B,

	TLS_RSA_WITH_AES_128_CBC_SHA		= 0x002F,	/* aes, aka rijndael with 128 bit blocks */
	TLS_DH_DSS_WITH_AES_128_CBC_SHA		= 0x0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA		= 0x0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA	= 0x0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0x0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA	= 0x0034,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0x0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA		= 0x0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA		= 0x0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0x0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0x0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0x003A,
	CipherMax
};
/* compression methods; only the null method is supported */
enum {
	CompressionNull = 0,
	CompressionMax
};
241 static Algs cipherAlgs[] = {
242 {"rc4_128", "md5", 2 * (16 + MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
243 {"rc4_128", "sha1", 2 * (16 + SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
244 {"3des_ede_cbc","sha1",2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
245 };
247 static uchar compressors[] = {
248 CompressionNull,
249 };
251 static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain);
252 static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...));
254 static void msgClear(Msg *m);
255 static char* msgPrint(char *buf, int n, Msg *m);
256 static int msgRecv(TlsConnection *c, Msg *m);
257 static int msgSend(TlsConnection *c, Msg *m, int act);
258 static void tlsError(TlsConnection *c, int err, char *msg, ...);
259 /* #pragma varargck argpos tlsError 3*/
260 static int setVersion(TlsConnection *c, int version);
261 static int finishedMatch(TlsConnection *c, Finished *f);
262 static void tlsConnectionFree(TlsConnection *c);
264 static int setAlgs(TlsConnection *c, int a);
265 static int okCipher(Ints *cv);
266 static int okCompression(Bytes *cv);
267 static int initCiphers(void);
268 static Ints* makeciphers(void);
270 static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
271 static int tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
272 static TlsSec* tlsSecInitc(int cvers, uchar *crandom);
273 static int tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
274 static int tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
275 static void tlsSecOk(TlsSec *sec);
276 /* static void tlsSecKill(TlsSec *sec); */
277 static void tlsSecClose(TlsSec *sec);
278 static void setMasterSecret(TlsSec *sec, Bytes *pm);
279 static void serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
280 static void setSecrets(TlsSec *sec, uchar *kd, int nkd);
281 static int clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
282 static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
283 static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
284 static void tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
285 static void sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
286 static void sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
287 uchar *seed0, int nseed0, uchar *seed1, int nseed1);
288 static int setVers(TlsSec *sec, int version);
290 static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
291 static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
292 static void factotum_rsa_close(AuthRpc*rpc);
294 static void* emalloc(int);
295 static void* erealloc(void*, int);
296 static void put32(uchar *p, u32int);
297 static void put24(uchar *p, int);
298 static void put16(uchar *p, int);
299 /* static u32int get32(uchar *p); */
300 static int get24(uchar *p);
301 static int get16(uchar *p);
302 static Bytes* newbytes(int len);
303 static Bytes* makebytes(uchar* buf, int len);
304 static void freebytes(Bytes* b);
305 static Ints* newints(int len);
306 /* static Ints* makeints(int* buf, int len); */
307 static void freeints(Ints* b);
309 /*================= client/server ======================== */
311 /* push TLS onto fd, returning new (application) file descriptor */
312 /* or -1 if error. */
313 int
314 tlsServer(int fd, TLSconn *conn)
316 char buf[8];
317 char dname[64];
318 int n, data, ctl, hand;
319 TlsConnection *tls;
321 if(conn == nil)
322 return -1;
323 ctl = open("#a/tls/clone", ORDWR);
324 if(ctl < 0)
325 return -1;
326 n = read(ctl, buf, sizeof(buf)-1);
327 if(n < 0){
328 close(ctl);
329 return -1;
331 buf[n] = 0;
332 sprint(conn->dir, "#a/tls/%s", buf);
333 sprint(dname, "#a/tls/%s/hand", buf);
334 hand = open(dname, ORDWR);
335 if(hand < 0){
336 close(ctl);
337 return -1;
339 fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
340 tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
341 sprint(dname, "#a/tls/%s/data", buf);
342 data = open(dname, ORDWR);
343 close(fd);
344 close(hand);
345 close(ctl);
346 if(data < 0){
347 return -1;
349 if(tls == nil){
350 close(data);
351 return -1;
353 if(conn->cert)
354 free(conn->cert);
355 conn->cert = 0; /* client certificates are not yet implemented */
356 conn->certlen = 0;
357 conn->sessionIDlen = tls->sid->len;
358 conn->sessionID = emalloc(conn->sessionIDlen);
359 memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
360 tlsConnectionFree(tls);
361 return data;
364 /* push TLS onto fd, returning new (application) file descriptor */
365 /* or -1 if error. */
366 int
367 tlsClient(int fd, TLSconn *conn)
369 char buf[8];
370 char dname[64];
371 int n, data, ctl, hand;
372 TlsConnection *tls;
374 if(!conn)
375 return -1;
376 ctl = open("#a/tls/clone", ORDWR);
377 if(ctl < 0)
378 return -1;
379 n = read(ctl, buf, sizeof(buf)-1);
380 if(n < 0){
381 close(ctl);
382 return -1;
384 buf[n] = 0;
385 sprint(conn->dir, "#a/tls/%s", buf);
386 sprint(dname, "#a/tls/%s/hand", buf);
387 hand = open(dname, ORDWR);
388 if(hand < 0){
389 close(ctl);
390 return -1;
392 sprint(dname, "#a/tls/%s/data", buf);
393 data = open(dname, ORDWR);
394 if(data < 0)
395 return -1;
396 fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
397 tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace);
398 close(fd);
399 close(hand);
400 close(ctl);
401 if(tls == nil){
402 close(data);
403 return -1;
405 conn->certlen = tls->cert->len;
406 conn->cert = emalloc(conn->certlen);
407 memcpy(conn->cert, tls->cert->data, conn->certlen);
408 conn->sessionIDlen = tls->sid->len;
409 conn->sessionID = emalloc(conn->sessionIDlen);
410 memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
411 tlsConnectionFree(tls);
412 return data;
415 static int
416 countchain(PEMChain *p)
418 int i = 0;
420 while (p) {
421 i++;
422 p = p->next;
424 return i;
427 static TlsConnection *
428 tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chp)
430 TlsConnection *c;
431 Msg m;
432 Bytes *csid;
433 uchar sid[SidSize], kd[MaxKeyData];
434 char *secrets;
435 int cipher, compressor, nsid, rv, numcerts, i;
437 if(trace)
438 trace("tlsServer2\n");
439 if(!initCiphers())
440 return nil;
441 c = emalloc(sizeof(TlsConnection));
442 c->ctl = ctl;
443 c->hand = hand;
444 c->trace = trace;
445 c->version = ProtocolVersion;
447 memset(&m, 0, sizeof(m));
448 if(!msgRecv(c, &m)){
449 if(trace)
450 trace("initial msgRecv failed\n");
451 goto Err;
453 if(m.tag != HClientHello) {
454 tlsError(c, EUnexpectedMessage, "expected a client hello");
455 goto Err;
457 c->clientVersion = m.u.clientHello.version;
458 if(trace)
459 trace("ClientHello version %x\n", c->clientVersion);
460 if(setVersion(c, m.u.clientHello.version) < 0) {
461 tlsError(c, EIllegalParameter, "incompatible version");
462 goto Err;
465 memmove(c->crandom, m.u.clientHello.random, RandomSize);
466 cipher = okCipher(m.u.clientHello.ciphers);
467 if(cipher < 0) {
468 /* reply with EInsufficientSecurity if we know that's the case */
469 if(cipher == -2)
470 tlsError(c, EInsufficientSecurity, "cipher suites too weak");
471 else
472 tlsError(c, EHandshakeFailure, "no matching cipher suite");
473 goto Err;
475 if(!setAlgs(c, cipher)){
476 tlsError(c, EHandshakeFailure, "no matching cipher suite");
477 goto Err;
479 compressor = okCompression(m.u.clientHello.compressors);
480 if(compressor < 0) {
481 tlsError(c, EHandshakeFailure, "no matching compressor");
482 goto Err;
485 csid = m.u.clientHello.sid;
486 if(trace)
487 trace(" cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
488 c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
489 if(c->sec == nil){
490 tlsError(c, EHandshakeFailure, "can't initialize security: %r");
491 goto Err;
493 c->sec->rpc = factotum_rsa_open(cert, ncert);
494 if(c->sec->rpc == nil){
495 tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
496 goto Err;
498 c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
499 msgClear(&m);
501 m.tag = HServerHello;
502 m.u.serverHello.version = c->version;
503 memmove(m.u.serverHello.random, c->srandom, RandomSize);
504 m.u.serverHello.cipher = cipher;
505 m.u.serverHello.compressor = compressor;
506 c->sid = makebytes(sid, nsid);
507 m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
508 if(!msgSend(c, &m, AQueue))
509 goto Err;
510 msgClear(&m);
512 m.tag = HCertificate;
513 numcerts = countchain(chp);
514 m.u.certificate.ncert = 1 + numcerts;
515 m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes));
516 m.u.certificate.certs[0] = makebytes(cert, ncert);
517 for (i = 0; i < numcerts && chp; i++, chp = chp->next)
518 m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
519 if(!msgSend(c, &m, AQueue))
520 goto Err;
521 msgClear(&m);
523 m.tag = HServerHelloDone;
524 if(!msgSend(c, &m, AFlush))
525 goto Err;
526 msgClear(&m);
528 if(!msgRecv(c, &m))
529 goto Err;
530 if(m.tag != HClientKeyExchange) {
531 tlsError(c, EUnexpectedMessage, "expected a client key exchange");
532 goto Err;
534 if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
535 tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
536 goto Err;
538 if(trace)
539 trace("tls secrets\n");
540 secrets = (char*)emalloc(2*c->nsecret);
541 enc64(secrets, 2*c->nsecret, kd, c->nsecret);
542 rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
543 memset(secrets, 0, 2*c->nsecret);
544 free(secrets);
545 memset(kd, 0, c->nsecret);
546 if(rv < 0){
547 tlsError(c, EHandshakeFailure, "can't set keys: %r");
548 goto Err;
550 msgClear(&m);
552 /* no CertificateVerify; skip to Finished */
553 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
554 tlsError(c, EInternalError, "can't set finished: %r");
555 goto Err;
557 if(!msgRecv(c, &m))
558 goto Err;
559 if(m.tag != HFinished) {
560 tlsError(c, EUnexpectedMessage, "expected a finished");
561 goto Err;
563 if(!finishedMatch(c, &m.u.finished)) {
564 tlsError(c, EHandshakeFailure, "finished verification failed");
565 goto Err;
567 msgClear(&m);
569 /* change cipher spec */
570 if(fprint(c->ctl, "changecipher") < 0){
571 tlsError(c, EInternalError, "can't enable cipher: %r");
572 goto Err;
575 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
576 tlsError(c, EInternalError, "can't set finished: %r");
577 goto Err;
579 m.tag = HFinished;
580 m.u.finished = c->finished;
581 if(!msgSend(c, &m, AFlush))
582 goto Err;
583 msgClear(&m);
584 if(trace)
585 trace("tls finished\n");
587 if(fprint(c->ctl, "opened") < 0)
588 goto Err;
589 tlsSecOk(c->sec);
590 return c;
592 Err:
593 msgClear(&m);
594 tlsConnectionFree(c);
595 return 0;
598 static TlsConnection *
599 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...))
601 TlsConnection *c;
602 Msg m;
603 uchar kd[MaxKeyData], *epm;
604 char *secrets;
605 int creq, nepm, rv;
607 if(!initCiphers())
608 return nil;
609 epm = nil;
610 c = emalloc(sizeof(TlsConnection));
611 c->version = ProtocolVersion;
612 c->ctl = ctl;
613 c->hand = hand;
614 c->trace = trace;
615 c->isClient = 1;
616 c->clientVersion = c->version;
618 c->sec = tlsSecInitc(c->clientVersion, c->crandom);
619 if(c->sec == nil)
620 goto Err;
622 /* client hello */
623 memset(&m, 0, sizeof(m));
624 m.tag = HClientHello;
625 m.u.clientHello.version = c->clientVersion;
626 memmove(m.u.clientHello.random, c->crandom, RandomSize);
627 m.u.clientHello.sid = makebytes(csid, ncsid);
628 m.u.clientHello.ciphers = makeciphers();
629 m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
630 if(!msgSend(c, &m, AFlush))
631 goto Err;
632 msgClear(&m);
634 /* server hello */
635 if(!msgRecv(c, &m))
636 goto Err;
637 if(m.tag != HServerHello) {
638 tlsError(c, EUnexpectedMessage, "expected a server hello");
639 goto Err;
641 if(setVersion(c, m.u.serverHello.version) < 0) {
642 tlsError(c, EIllegalParameter, "incompatible version %r");
643 goto Err;
645 memmove(c->srandom, m.u.serverHello.random, RandomSize);
646 c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
647 if(c->sid->len != 0 && c->sid->len != SidSize) {
648 tlsError(c, EIllegalParameter, "invalid server session identifier");
649 goto Err;
651 if(!setAlgs(c, m.u.serverHello.cipher)) {
652 tlsError(c, EIllegalParameter, "invalid cipher suite");
653 goto Err;
655 if(m.u.serverHello.compressor != CompressionNull) {
656 tlsError(c, EIllegalParameter, "invalid compression");
657 goto Err;
659 msgClear(&m);
661 /* certificate */
662 if(!msgRecv(c, &m) || m.tag != HCertificate) {
663 tlsError(c, EUnexpectedMessage, "expected a certificate");
664 goto Err;
666 if(m.u.certificate.ncert < 1) {
667 tlsError(c, EIllegalParameter, "runt certificate");
668 goto Err;
670 c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
671 msgClear(&m);
673 /* server key exchange (optional) */
674 if(!msgRecv(c, &m))
675 goto Err;
676 if(m.tag == HServerKeyExchange) {
677 tlsError(c, EUnexpectedMessage, "got an server key exchange");
678 goto Err;
679 /* If implementing this later, watch out for rollback attack */
680 /* described in Wagner Schneier 1996, section 4.4. */
683 /* certificate request (optional) */
684 creq = 0;
685 if(m.tag == HCertificateRequest) {
686 creq = 1;
687 msgClear(&m);
688 if(!msgRecv(c, &m))
689 goto Err;
692 if(m.tag != HServerHelloDone) {
693 tlsError(c, EUnexpectedMessage, "expected a server hello done");
694 goto Err;
696 msgClear(&m);
698 if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
699 c->cert->data, c->cert->len, c->version, &epm, &nepm,
700 kd, c->nsecret) < 0){
701 tlsError(c, EBadCertificate, "invalid x509/rsa certificate");
702 goto Err;
704 secrets = (char*)emalloc(2*c->nsecret);
705 enc64(secrets, 2*c->nsecret, kd, c->nsecret);
706 rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
707 memset(secrets, 0, 2*c->nsecret);
708 free(secrets);
709 memset(kd, 0, c->nsecret);
710 if(rv < 0){
711 tlsError(c, EHandshakeFailure, "can't set keys: %r");
712 goto Err;
715 if(creq) {
716 /* send a zero length certificate */
717 m.tag = HCertificate;
718 if(!msgSend(c, &m, AFlush))
719 goto Err;
720 msgClear(&m);
723 /* client key exchange */
724 m.tag = HClientKeyExchange;
725 m.u.clientKeyExchange.key = makebytes(epm, nepm);
726 free(epm);
727 epm = nil;
728 if(m.u.clientKeyExchange.key == nil) {
729 tlsError(c, EHandshakeFailure, "can't set secret: %r");
730 goto Err;
732 if(!msgSend(c, &m, AFlush))
733 goto Err;
734 msgClear(&m);
736 /* change cipher spec */
737 if(fprint(c->ctl, "changecipher") < 0){
738 tlsError(c, EInternalError, "can't enable cipher: %r");
739 goto Err;
742 /* Cipherchange must occur immediately before Finished to avoid */
743 /* potential hole; see section 4.3 of Wagner Schneier 1996. */
744 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
745 tlsError(c, EInternalError, "can't set finished 1: %r");
746 goto Err;
748 m.tag = HFinished;
749 m.u.finished = c->finished;
751 if(!msgSend(c, &m, AFlush)) {
752 fprint(2, "tlsClient nepm=%d\n", nepm);
753 tlsError(c, EInternalError, "can't flush after client Finished: %r");
754 goto Err;
756 msgClear(&m);
758 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
759 fprint(2, "tlsClient nepm=%d\n", nepm);
760 tlsError(c, EInternalError, "can't set finished 0: %r");
761 goto Err;
763 if(!msgRecv(c, &m)) {
764 fprint(2, "tlsClient nepm=%d\n", nepm);
765 tlsError(c, EInternalError, "can't read server Finished: %r");
766 goto Err;
768 if(m.tag != HFinished) {
769 fprint(2, "tlsClient nepm=%d\n", nepm);
770 tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
771 goto Err;
774 if(!finishedMatch(c, &m.u.finished)) {
775 tlsError(c, EHandshakeFailure, "finished verification failed");
776 goto Err;
778 msgClear(&m);
780 if(fprint(c->ctl, "opened") < 0){
781 if(trace)
782 trace("unable to do final open: %r\n");
783 goto Err;
785 tlsSecOk(c->sec);
786 return c;
788 Err:
789 free(epm);
790 msgClear(&m);
791 tlsConnectionFree(c);
792 return 0;
796 /*================= message functions ======================== */
798 static uchar sendbuf[9000], *sendp;
800 static int
801 msgSend(TlsConnection *c, Msg *m, int act)
803 uchar *p; /* sendp = start of new message; p = write pointer */
804 int nn, n, i;
806 if(sendp == nil)
807 sendp = sendbuf;
808 p = sendp;
809 if(c->trace)
810 c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m));
812 p[0] = m->tag; /* header - fill in size later */
813 p += 4;
815 switch(m->tag) {
816 default:
817 tlsError(c, EInternalError, "can't encode a %d", m->tag);
818 goto Err;
819 case HClientHello:
820 /* version */
821 put16(p, m->u.clientHello.version);
822 p += 2;
824 /* random */
825 memmove(p, m->u.clientHello.random, RandomSize);
826 p += RandomSize;
828 /* sid */
829 n = m->u.clientHello.sid->len;
830 assert(n < 256);
831 p[0] = n;
832 memmove(p+1, m->u.clientHello.sid->data, n);
833 p += n+1;
835 n = m->u.clientHello.ciphers->len;
836 assert(n > 0 && n < 200);
837 put16(p, n*2);
838 p += 2;
839 for(i=0; i<n; i++) {
840 put16(p, m->u.clientHello.ciphers->data[i]);
841 p += 2;
844 n = m->u.clientHello.compressors->len;
845 assert(n > 0);
846 p[0] = n;
847 memmove(p+1, m->u.clientHello.compressors->data, n);
848 p += n+1;
849 break;
850 case HServerHello:
851 put16(p, m->u.serverHello.version);
852 p += 2;
854 /* random */
855 memmove(p, m->u.serverHello.random, RandomSize);
856 p += RandomSize;
858 /* sid */
859 n = m->u.serverHello.sid->len;
860 assert(n < 256);
861 p[0] = n;
862 memmove(p+1, m->u.serverHello.sid->data, n);
863 p += n+1;
865 put16(p, m->u.serverHello.cipher);
866 p += 2;
867 p[0] = m->u.serverHello.compressor;
868 p += 1;
869 break;
870 case HServerHelloDone:
871 break;
872 case HCertificate:
873 nn = 0;
874 for(i = 0; i < m->u.certificate.ncert; i++)
875 nn += 3 + m->u.certificate.certs[i]->len;
876 if(p + 3 + nn - sendbuf > sizeof(sendbuf)) {
877 tlsError(c, EInternalError, "output buffer too small for certificate");
878 goto Err;
880 put24(p, nn);
881 p += 3;
882 for(i = 0; i < m->u.certificate.ncert; i++){
883 put24(p, m->u.certificate.certs[i]->len);
884 p += 3;
885 memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
886 p += m->u.certificate.certs[i]->len;
888 break;
889 case HClientKeyExchange:
890 n = m->u.clientKeyExchange.key->len;
891 if(c->version != SSL3Version){
892 put16(p, n);
893 p += 2;
895 memmove(p, m->u.clientKeyExchange.key->data, n);
896 p += n;
897 break;
898 case HFinished:
899 memmove(p, m->u.finished.verify, m->u.finished.n);
900 p += m->u.finished.n;
901 break;
904 /* go back and fill in size */
905 n = p-sendp;
906 assert(p <= sendbuf+sizeof(sendbuf));
907 put24(sendp+1, n-4);
909 /* remember hash of Handshake messages */
910 if(m->tag != HHelloRequest) {
911 md5(sendp, n, 0, &c->hsmd5);
912 sha1(sendp, n, 0, &c->hssha1);
915 sendp = p;
916 if(act == AFlush){
917 sendp = sendbuf;
918 if(write(c->hand, sendbuf, p-sendbuf) < 0){
919 fprint(2, "write error: %r\n");
920 goto Err;
923 msgClear(m);
924 return 1;
925 Err:
926 msgClear(m);
927 return 0;
930 static uchar*
931 tlsReadN(TlsConnection *c, int n)
933 uchar *p;
934 int nn, nr;
936 nn = c->ep - c->rp;
937 if(nn < n){
938 if(c->rp != c->buf){
939 memmove(c->buf, c->rp, nn);
940 c->rp = c->buf;
941 c->ep = &c->buf[nn];
943 for(; nn < n; nn += nr) {
944 nr = read(c->hand, &c->rp[nn], n - nn);
945 if(nr <= 0)
946 return nil;
947 c->ep += nr;
950 p = c->rp;
951 c->rp += n;
952 return p;
955 static int
956 msgRecv(TlsConnection *c, Msg *m)
958 uchar *p;
959 int type, n, nn, i, nsid, nrandom, nciph;
961 for(;;) {
962 p = tlsReadN(c, 4);
963 if(p == nil)
964 return 0;
965 type = p[0];
966 n = get24(p+1);
968 if(type != HHelloRequest)
969 break;
970 if(n != 0) {
971 tlsError(c, EDecodeError, "invalid hello request during handshake");
972 return 0;
976 if(n > sizeof(c->buf)) {
977 tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf));
978 return 0;
981 if(type == HSSL2ClientHello){
982 /* Cope with an SSL3 ClientHello expressed in SSL2 record format.
983 This is sent by some clients that we must interoperate
984 with, such as Java's JSSE and Microsoft's Internet Explorer. */
985 p = tlsReadN(c, n);
986 if(p == nil)
987 return 0;
988 md5(p, n, 0, &c->hsmd5);
989 sha1(p, n, 0, &c->hssha1);
990 m->tag = HClientHello;
991 if(n < 22)
992 goto Short;
993 m->u.clientHello.version = get16(p+1);
994 p += 3;
995 n -= 3;
996 nn = get16(p); /* cipher_spec_len */
997 nsid = get16(p + 2);
998 nrandom = get16(p + 4);
999 p += 6;
1000 n -= 6;
1001 if(nsid != 0 /* no sid's, since shouldn't restart using ssl2 header */
1002 || nrandom < 16 || nn % 3)
1003 goto Err;
1004 if(c->trace && (n - nrandom != nn))
1005 c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
1006 /* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
1007 nciph = 0;
1008 for(i = 0; i < nn; i += 3)
1009 if(p[i] == 0)
1010 nciph++;
1011 m->u.clientHello.ciphers = newints(nciph);
1012 nciph = 0;
1013 for(i = 0; i < nn; i += 3)
1014 if(p[i] == 0)
1015 m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1016 p += nn;
1017 m->u.clientHello.sid = makebytes(nil, 0);
1018 if(nrandom > RandomSize)
1019 nrandom = RandomSize;
1020 memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1021 memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1022 m->u.clientHello.compressors = newbytes(1);
1023 m->u.clientHello.compressors->data[0] = CompressionNull;
1024 goto Ok;
1027 md5(p, 4, 0, &c->hsmd5);
1028 sha1(p, 4, 0, &c->hssha1);
1030 p = tlsReadN(c, n);
1031 if(p == nil)
1032 return 0;
1034 md5(p, n, 0, &c->hsmd5);
1035 sha1(p, n, 0, &c->hssha1);
1037 m->tag = type;
1039 switch(type) {
1040 default:
1041 tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1042 goto Err;
1043 case HClientHello:
1044 if(n < 2)
1045 goto Short;
1046 m->u.clientHello.version = get16(p);
1047 p += 2;
1048 n -= 2;
1050 if(n < RandomSize)
1051 goto Short;
1052 memmove(m->u.clientHello.random, p, RandomSize);
1053 p += RandomSize;
1054 n -= RandomSize;
1055 if(n < 1 || n < p[0]+1)
1056 goto Short;
1057 m->u.clientHello.sid = makebytes(p+1, p[0]);
1058 p += m->u.clientHello.sid->len+1;
1059 n -= m->u.clientHello.sid->len+1;
1061 if(n < 2)
1062 goto Short;
1063 nn = get16(p);
1064 p += 2;
1065 n -= 2;
1067 if((nn & 1) || n < nn || nn < 2)
1068 goto Short;
1069 m->u.clientHello.ciphers = newints(nn >> 1);
1070 for(i = 0; i < nn; i += 2)
1071 m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1072 p += nn;
1073 n -= nn;
1075 if(n < 1 || n < p[0]+1 || p[0] == 0)
1076 goto Short;
1077 nn = p[0];
1078 m->u.clientHello.compressors = newbytes(nn);
1079 memmove(m->u.clientHello.compressors->data, p+1, nn);
1080 n -= nn + 1;
1081 break;
1082 case HServerHello:
1083 if(n < 2)
1084 goto Short;
1085 m->u.serverHello.version = get16(p);
1086 p += 2;
1087 n -= 2;
1089 if(n < RandomSize)
1090 goto Short;
1091 memmove(m->u.serverHello.random, p, RandomSize);
1092 p += RandomSize;
1093 n -= RandomSize;
1095 if(n < 1 || n < p[0]+1)
1096 goto Short;
1097 m->u.serverHello.sid = makebytes(p+1, p[0]);
1098 p += m->u.serverHello.sid->len+1;
1099 n -= m->u.serverHello.sid->len+1;
1101 if(n < 3)
1102 goto Short;
1103 m->u.serverHello.cipher = get16(p);
1104 m->u.serverHello.compressor = p[2];
1105 n -= 3;
1106 break;
1107 case HCertificate:
1108 if(n < 3)
1109 goto Short;
1110 nn = get24(p);
1111 p += 3;
1112 n -= 3;
1113 if(n != nn)
1114 goto Short;
1115 /* certs */
1116 i = 0;
1117 while(n > 0) {
1118 if(n < 3)
1119 goto Short;
1120 nn = get24(p);
1121 p += 3;
1122 n -= 3;
1123 if(nn > n)
1124 goto Short;
1125 m->u.certificate.ncert = i+1;
1126 m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes));
1127 m->u.certificate.certs[i] = makebytes(p, nn);
1128 p += nn;
1129 n -= nn;
1130 i++;
1132 break;
1133 case HCertificateRequest:
1134 if(n < 2)
1135 goto Short;
1136 nn = get16(p);
1137 p += 2;
1138 n -= 2;
1139 if(nn < 1 || nn > n)
1140 goto Short;
1141 m->u.certificateRequest.types = makebytes(p, nn);
1142 nn = get24(p);
1143 p += 3;
1144 n -= 3;
1145 if(nn == 0 || n != nn)
1146 goto Short;
1147 /* cas */
1148 i = 0;
1149 while(n > 0) {
1150 if(n < 2)
1151 goto Short;
1152 nn = get16(p);
1153 p += 2;
1154 n -= 2;
1155 if(nn < 1 || nn > n)
1156 goto Short;
1157 m->u.certificateRequest.nca = i+1;
1158 m->u.certificateRequest.cas = erealloc(m->u.certificateRequest.cas, (i+1)*sizeof(Bytes));
1159 m->u.certificateRequest.cas[i] = makebytes(p, nn);
1160 p += nn;
1161 n -= nn;
1162 i++;
1164 break;
1165 case HServerHelloDone:
1166 break;
1167 case HClientKeyExchange:
1169 * this message depends upon the encryption selected
1170 * assume rsa.
1172 if(c->version == SSL3Version)
1173 nn = n;
1174 else{
1175 if(n < 2)
1176 goto Short;
1177 nn = get16(p);
1178 p += 2;
1179 n -= 2;
1181 if(n < nn)
1182 goto Short;
1183 m->u.clientKeyExchange.key = makebytes(p, nn);
1184 n -= nn;
1185 break;
1186 case HFinished:
1187 m->u.finished.n = c->finished.n;
1188 if(n < m->u.finished.n)
1189 goto Short;
1190 memmove(m->u.finished.verify, p, m->u.finished.n);
1191 n -= m->u.finished.n;
1192 break;
1195 if(type != HClientHello && n != 0)
1196 goto Short;
1197 Ok:
1198 if(c->trace){
1199 char buf[8000];
1200 c->trace("recv %s", msgPrint(buf, sizeof buf, m));
1202 return 1;
1203 Short:
1204 tlsError(c, EDecodeError, "handshake message has invalid length");
1205 Err:
1206 msgClear(m);
1207 return 0;
1210 static void
1211 msgClear(Msg *m)
1213 int i;
1215 switch(m->tag) {
1216 default:
1217 sysfatal("msgClear: unknown message type: %d\n", m->tag);
1218 case HHelloRequest:
1219 break;
1220 case HClientHello:
1221 freebytes(m->u.clientHello.sid);
1222 freeints(m->u.clientHello.ciphers);
1223 freebytes(m->u.clientHello.compressors);
1224 break;
1225 case HServerHello:
1226 freebytes(m->u.clientHello.sid);
1227 break;
1228 case HCertificate:
1229 for(i=0; i<m->u.certificate.ncert; i++)
1230 freebytes(m->u.certificate.certs[i]);
1231 free(m->u.certificate.certs);
1232 break;
1233 case HCertificateRequest:
1234 freebytes(m->u.certificateRequest.types);
1235 for(i=0; i<m->u.certificateRequest.nca; i++)
1236 freebytes(m->u.certificateRequest.cas[i]);
1237 free(m->u.certificateRequest.cas);
1238 break;
1239 case HServerHelloDone:
1240 break;
1241 case HClientKeyExchange:
1242 freebytes(m->u.clientKeyExchange.key);
1243 break;
1244 case HFinished:
1245 break;
1247 memset(m, 0, sizeof(Msg));
1250 static char *
1251 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1253 int i;
1255 if(s0)
1256 bs = seprint(bs, be, "%s", s0);
1257 bs = seprint(bs, be, "[");
1258 if(b == nil)
1259 bs = seprint(bs, be, "nil");
1260 else
1261 for(i=0; i<b->len; i++)
1262 bs = seprint(bs, be, "%.2x ", b->data[i]);
1263 bs = seprint(bs, be, "]");
1264 if(s1)
1265 bs = seprint(bs, be, "%s", s1);
1266 return bs;
1269 static char *
1270 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1272 int i;
1274 if(s0)
1275 bs = seprint(bs, be, "%s", s0);
1276 bs = seprint(bs, be, "[");
1277 if(b == nil)
1278 bs = seprint(bs, be, "nil");
1279 else
1280 for(i=0; i<b->len; i++)
1281 bs = seprint(bs, be, "%x ", b->data[i]);
1282 bs = seprint(bs, be, "]");
1283 if(s1)
1284 bs = seprint(bs, be, "%s", s1);
1285 return bs;
1288 static char*
1289 msgPrint(char *buf, int n, Msg *m)
1291 int i;
1292 char *bs = buf, *be = buf+n;
1294 switch(m->tag) {
1295 default:
1296 bs = seprint(bs, be, "unknown %d\n", m->tag);
1297 break;
1298 case HClientHello:
1299 bs = seprint(bs, be, "ClientHello\n");
1300 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1301 bs = seprint(bs, be, "\trandom: ");
1302 for(i=0; i<RandomSize; i++)
1303 bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1304 bs = seprint(bs, be, "\n");
1305 bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1306 bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1307 bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1308 break;
1309 case HServerHello:
1310 bs = seprint(bs, be, "ServerHello\n");
1311 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1312 bs = seprint(bs, be, "\trandom: ");
1313 for(i=0; i<RandomSize; i++)
1314 bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1315 bs = seprint(bs, be, "\n");
1316 bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1317 bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1318 bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1319 break;
1320 case HCertificate:
1321 bs = seprint(bs, be, "Certificate\n");
1322 for(i=0; i<m->u.certificate.ncert; i++)
1323 bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1324 break;
1325 case HCertificateRequest:
1326 bs = seprint(bs, be, "CertificateRequest\n");
1327 bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1328 bs = seprint(bs, be, "\tcertificateauthorities\n");
1329 for(i=0; i<m->u.certificateRequest.nca; i++)
1330 bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1331 break;
1332 case HServerHelloDone:
1333 bs = seprint(bs, be, "ServerHelloDone\n");
1334 break;
1335 case HClientKeyExchange:
1336 bs = seprint(bs, be, "HClientKeyExchange\n");
1337 bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1338 break;
1339 case HFinished:
1340 bs = seprint(bs, be, "HFinished\n");
1341 for(i=0; i<m->u.finished.n; i++)
1342 bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1343 bs = seprint(bs, be, "\n");
1344 break;
1346 USED(bs);
1347 return buf;
1350 static void
1351 tlsError(TlsConnection *c, int err, char *fmt, ...)
1353 char msg[512];
1354 va_list arg;
1356 va_start(arg, fmt);
1357 vseprint(msg, msg+sizeof(msg), fmt, arg);
1358 va_end(arg);
1359 if(c->trace)
1360 c->trace("tlsError: %s\n", msg);
1361 else if(c->erred)
1362 fprint(2, "double error: %r, %s", msg);
1363 else
1364 werrstr("tls: local %s", msg);
1365 c->erred = 1;
1366 fprint(c->ctl, "alert %d", err);
1369 /* commit to specific version number */
1370 static int
1371 setVersion(TlsConnection *c, int version)
1373 if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1374 return -1;
1375 if(version > c->version)
1376 version = c->version;
1377 if(version == SSL3Version) {
1378 c->version = version;
1379 c->finished.n = SSL3FinishedLen;
1380 }else if(version == TLSVersion){
1381 c->version = version;
1382 c->finished.n = TLSFinishedLen;
1383 }else
1384 return -1;
1385 c->verset = 1;
1386 return fprint(c->ctl, "version 0x%x", version);
1389 /* confirm that received Finished message matches the expected value */
1390 static int
1391 finishedMatch(TlsConnection *c, Finished *f)
1393 return memcmp(f->verify, c->finished.verify, f->n) == 0;
1396 /* free memory associated with TlsConnection struct */
1397 /* (but don't close the TLS channel itself) */
1398 static void
1399 tlsConnectionFree(TlsConnection *c)
1401 tlsSecClose(c->sec);
1402 freebytes(c->sid);
1403 freebytes(c->cert);
1404 memset(c, 0, sizeof(*c));
1405 free(c);
1409 /*================= cipher choices ======================== */
1411 static int weakCipher[CipherMax] =
1413 1, /* TLS_NULL_WITH_NULL_NULL */
1414 1, /* TLS_RSA_WITH_NULL_MD5 */
1415 1, /* TLS_RSA_WITH_NULL_SHA */
1416 1, /* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
1417 0, /* TLS_RSA_WITH_RC4_128_MD5 */
1418 0, /* TLS_RSA_WITH_RC4_128_SHA */
1419 1, /* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
1420 0, /* TLS_RSA_WITH_IDEA_CBC_SHA */
1421 1, /* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
1422 0, /* TLS_RSA_WITH_DES_CBC_SHA */
1423 0, /* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
1424 1, /* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
1425 0, /* TLS_DH_DSS_WITH_DES_CBC_SHA */
1426 0, /* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
1427 1, /* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
1428 0, /* TLS_DH_RSA_WITH_DES_CBC_SHA */
1429 0, /* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
1430 1, /* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
1431 0, /* TLS_DHE_DSS_WITH_DES_CBC_SHA */
1432 0, /* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
1433 1, /* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
1434 0, /* TLS_DHE_RSA_WITH_DES_CBC_SHA */
1435 0, /* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
1436 1, /* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
1437 1, /* TLS_DH_anon_WITH_RC4_128_MD5 */
1438 1, /* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
1439 1, /* TLS_DH_anon_WITH_DES_CBC_SHA */
1440 1, /* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
1443 static int
1444 setAlgs(TlsConnection *c, int a)
1446 int i;
1448 for(i = 0; i < nelem(cipherAlgs); i++){
1449 if(cipherAlgs[i].tlsid == a){
1450 c->enc = cipherAlgs[i].enc;
1451 c->digest = cipherAlgs[i].digest;
1452 c->nsecret = cipherAlgs[i].nsecret;
1453 if(c->nsecret > MaxKeyData)
1454 return 0;
1455 return 1;
1458 return 0;
1461 static int
1462 okCipher(Ints *cv)
1464 int weak, i, j, c;
1466 weak = 1;
1467 for(i = 0; i < cv->len; i++) {
1468 c = cv->data[i];
1469 if(c >= CipherMax)
1470 weak = 0;
1471 else
1472 weak &= weakCipher[c];
1473 for(j = 0; j < nelem(cipherAlgs); j++)
1474 if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
1475 return c;
1477 if(weak)
1478 return -2;
1479 return -1;
1482 static int
1483 okCompression(Bytes *cv)
1485 int i, j, c;
1487 for(i = 0; i < cv->len; i++) {
1488 c = cv->data[i];
1489 for(j = 0; j < nelem(compressors); j++) {
1490 if(compressors[j] == c)
1491 return c;
1494 return -1;
1497 static Lock ciphLock;
1498 static int nciphers;
1500 static int
1501 initCiphers(void)
1503 enum {MaxAlgF = 1024, MaxAlgs = 10};
1504 char s[MaxAlgF], *flds[MaxAlgs];
1505 int i, j, n, ok;
1507 lock(&ciphLock);
1508 if(nciphers){
1509 unlock(&ciphLock);
1510 return nciphers;
1512 j = open("#a/tls/encalgs", OREAD);
1513 if(j < 0){
1514 werrstr("can't open #a/tls/encalgs: %r");
1515 return 0;
1517 n = read(j, s, MaxAlgF-1);
1518 close(j);
1519 if(n <= 0){
1520 werrstr("nothing in #a/tls/encalgs: %r");
1521 return 0;
1523 s[n] = 0;
1524 n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1525 for(i = 0; i < nelem(cipherAlgs); i++){
1526 ok = 0;
1527 for(j = 0; j < n; j++){
1528 if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
1529 ok = 1;
1530 break;
1533 cipherAlgs[i].ok = ok;
1536 j = open("#a/tls/hashalgs", OREAD);
1537 if(j < 0){
1538 werrstr("can't open #a/tls/hashalgs: %r");
1539 return 0;
1541 n = read(j, s, MaxAlgF-1);
1542 close(j);
1543 if(n <= 0){
1544 werrstr("nothing in #a/tls/hashalgs: %r");
1545 return 0;
1547 s[n] = 0;
1548 n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1549 for(i = 0; i < nelem(cipherAlgs); i++){
1550 ok = 0;
1551 for(j = 0; j < n; j++){
1552 if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
1553 ok = 1;
1554 break;
1557 cipherAlgs[i].ok &= ok;
1558 if(cipherAlgs[i].ok)
1559 nciphers++;
1561 unlock(&ciphLock);
1562 return nciphers;
1565 static Ints*
1566 makeciphers(void)
1568 Ints *is;
1569 int i, j;
1571 is = newints(nciphers);
1572 j = 0;
1573 for(i = 0; i < nelem(cipherAlgs); i++){
1574 if(cipherAlgs[i].ok)
1575 is->data[j++] = cipherAlgs[i].tlsid;
1577 return is;
1582 /*================= security functions ======================== */
1584 /* given X.509 certificate, set up connection to factotum */
1585 /* for using corresponding private key */
1586 static AuthRpc*
1587 factotum_rsa_open(uchar *cert, int certlen)
1589 char *s;
1590 mpint *pub = nil;
1591 RSApub *rsapub;
1592 AuthRpc *rpc;
1594 if((rpc = auth_allocrpc()) == nil){
1595 return nil;
1597 s = "proto=rsa service=tls role=client";
1598 if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
1599 factotum_rsa_close(rpc);
1600 return nil;
1603 /* roll factotum keyring around to match certificate */
1604 rsapub = X509toRSApub(cert, certlen, nil, 0);
1605 while(1){
1606 if(auth_rpc(rpc, "read", nil, 0) != ARok){
1607 factotum_rsa_close(rpc);
1608 rpc = nil;
1609 goto done;
1611 pub = strtomp(rpc->arg, nil, 16, nil);
1612 assert(pub != nil);
1613 if(mpcmp(pub,rsapub->n) == 0)
1614 break;
1616 done:
1617 mpfree(pub);
1618 rsapubfree(rsapub);
1619 return rpc;
1622 static mpint*
1623 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
1625 char *p;
1626 int rv;
1628 if((p = mptoa(cipher, 16, nil, 0)) == nil)
1629 return nil;
1630 rv = auth_rpc(rpc, "write", p, strlen(p));
1631 free(p);
1632 if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
1633 return nil;
1634 mpfree(cipher);
1635 return strtomp(rpc->arg, nil, 16, nil);
1638 static void
1639 factotum_rsa_close(AuthRpc*rpc)
1641 if(!rpc)
1642 return;
1643 close(rpc->afd);
1644 auth_freerpc(rpc);
1647 static void
1648 tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1650 uchar ai[MD5dlen], tmp[MD5dlen];
1651 int i, n;
1652 MD5state *s;
1654 /* generate a1 */
1655 s = hmac_md5(label, nlabel, key, nkey, nil, nil);
1656 s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1657 hmac_md5(seed1, nseed1, key, nkey, ai, s);
1659 while(nbuf > 0) {
1660 s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
1661 s = hmac_md5(label, nlabel, key, nkey, nil, s);
1662 s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1663 hmac_md5(seed1, nseed1, key, nkey, tmp, s);
1664 n = MD5dlen;
1665 if(n > nbuf)
1666 n = nbuf;
1667 for(i = 0; i < n; i++)
1668 buf[i] ^= tmp[i];
1669 buf += n;
1670 nbuf -= n;
1671 hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
1672 memmove(ai, tmp, MD5dlen);
1676 static void
1677 tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1679 uchar ai[SHA1dlen], tmp[SHA1dlen];
1680 int i, n;
1681 SHAstate *s;
1683 /* generate a1 */
1684 s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
1685 s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1686 hmac_sha1(seed1, nseed1, key, nkey, ai, s);
1688 while(nbuf > 0) {
1689 s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
1690 s = hmac_sha1(label, nlabel, key, nkey, nil, s);
1691 s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1692 hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
1693 n = SHA1dlen;
1694 if(n > nbuf)
1695 n = nbuf;
1696 for(i = 0; i < n; i++)
1697 buf[i] ^= tmp[i];
1698 buf += n;
1699 nbuf -= n;
1700 hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
1701 memmove(ai, tmp, SHA1dlen);
1705 /* fill buf with md5(args)^sha1(args) */
1706 static void
1707 tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1709 int i;
1710 int nlabel = strlen(label);
1711 int n = (nkey + 1) >> 1;
1713 for(i = 0; i < nbuf; i++)
1714 buf[i] = 0;
1715 tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1716 tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1720 * for setting server session id's
1722 static Lock sidLock;
1723 static long maxSid = 1;
1725 /* the keys are verified to have the same public components
1726 * and to function correctly with pkcs 1 encryption and decryption. */
1727 static TlsSec*
1728 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
1730 TlsSec *sec = emalloc(sizeof(*sec));
1732 USED(csid); USED(ncsid); /* ignore csid for now */
1734 memmove(sec->crandom, crandom, RandomSize);
1735 sec->clientVers = cvers;
1737 put32(sec->srandom, time(0));
1738 genrandom(sec->srandom+4, RandomSize-4);
1739 memmove(srandom, sec->srandom, RandomSize);
1742 * make up a unique sid: use our pid, and and incrementing id
1743 * can signal no sid by setting nssid to 0.
1745 memset(ssid, 0, SidSize);
1746 put32(ssid, getpid());
1747 lock(&sidLock);
1748 put32(ssid+4, maxSid++);
1749 unlock(&sidLock);
1750 *nssid = SidSize;
1751 return sec;
1754 static int
1755 tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
1757 if(epm != nil){
1758 if(setVers(sec, vers) < 0)
1759 goto Err;
1760 serverMasterSecret(sec, epm, nepm);
1761 }else if(sec->vers != vers){
1762 werrstr("mismatched session versions");
1763 goto Err;
1765 setSecrets(sec, kd, nkd);
1766 return 0;
1767 Err:
1768 sec->ok = -1;
1769 return -1;
1772 static TlsSec*
1773 tlsSecInitc(int cvers, uchar *crandom)
1775 TlsSec *sec = emalloc(sizeof(*sec));
1776 sec->clientVers = cvers;
1777 put32(sec->crandom, time(0));
1778 genrandom(sec->crandom+4, RandomSize-4);
1779 memmove(crandom, sec->crandom, RandomSize);
1780 return sec;
1783 static int
1784 tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
1786 RSApub *pub;
1788 pub = nil;
1790 USED(sid);
1791 USED(nsid);
1793 memmove(sec->srandom, srandom, RandomSize);
1795 if(setVers(sec, vers) < 0)
1796 goto Err;
1798 pub = X509toRSApub(cert, ncert, nil, 0);
1799 if(pub == nil){
1800 werrstr("invalid x509/rsa certificate");
1801 goto Err;
1803 if(clientMasterSecret(sec, pub, epm, nepm) < 0)
1804 goto Err;
1805 rsapubfree(pub);
1806 setSecrets(sec, kd, nkd);
1807 return 0;
1809 Err:
1810 if(pub != nil)
1811 rsapubfree(pub);
1812 sec->ok = -1;
1813 return -1;
1816 static int
1817 tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
1819 if(sec->nfin != nfin){
1820 sec->ok = -1;
1821 werrstr("invalid finished exchange");
1822 return -1;
1824 md5.malloced = 0;
1825 sha1.malloced = 0;
1826 (*sec->setFinished)(sec, md5, sha1, fin, isclient);
1827 return 1;
1830 static void
1831 tlsSecOk(TlsSec *sec)
1833 if(sec->ok == 0)
1834 sec->ok = 1;
1838 static void
1839 tlsSecKill(TlsSec *sec)
1841 if(!sec)
1842 return;
1843 factotum_rsa_close(sec->rpc);
1844 sec->ok = -1;
1848 static void
1849 tlsSecClose(TlsSec *sec)
1851 if(!sec)
1852 return;
1853 factotum_rsa_close(sec->rpc);
1854 free(sec->server);
1855 free(sec);
1858 static int
1859 setVers(TlsSec *sec, int v)
1861 if(v == SSL3Version){
1862 sec->setFinished = sslSetFinished;
1863 sec->nfin = SSL3FinishedLen;
1864 sec->prf = sslPRF;
1865 }else if(v == TLSVersion){
1866 sec->setFinished = tlsSetFinished;
1867 sec->nfin = TLSFinishedLen;
1868 sec->prf = tlsPRF;
1869 }else{
1870 werrstr("invalid version");
1871 return -1;
1873 sec->vers = v;
1874 return 0;
1878 * generate secret keys from the master secret.
1880 * different crypto selections will require different amounts
1881 * of key expansion and use of key expansion data,
1882 * but it's all generated using the same function.
1884 static void
1885 setSecrets(TlsSec *sec, uchar *kd, int nkd)
1887 (*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
1888 sec->srandom, RandomSize, sec->crandom, RandomSize);
1892 * set the master secret from the pre-master secret.
1894 static void
1895 setMasterSecret(TlsSec *sec, Bytes *pm)
1897 (*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
1898 sec->crandom, RandomSize, sec->srandom, RandomSize);
1901 static void
1902 serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
1904 Bytes *pm;
1906 pm = pkcs1_decrypt(sec, epm, nepm);
1908 /* if the client messed up, just continue as if everything is ok, */
1909 /* to prevent attacks to check for correctly formatted messages. */
1910 /* Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client. */
1911 if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
1912 fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
1913 sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
1914 sec->ok = -1;
1915 if(pm != nil)
1916 freebytes(pm);
1917 pm = newbytes(MasterSecretSize);
1918 genrandom(pm->data, MasterSecretSize);
1920 setMasterSecret(sec, pm);
1921 memset(pm->data, 0, pm->len);
1922 freebytes(pm);
1925 static int
1926 clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
1928 Bytes *pm, *key;
1930 pm = newbytes(MasterSecretSize);
1931 put16(pm->data, sec->clientVers);
1932 genrandom(pm->data+2, MasterSecretSize - 2);
1934 setMasterSecret(sec, pm);
1936 key = pkcs1_encrypt(pm, pub, 2);
1937 memset(pm->data, 0, pm->len);
1938 freebytes(pm);
1939 if(key == nil){
1940 werrstr("tls pkcs1_encrypt failed");
1941 return -1;
1944 *nepm = key->len;
1945 *epm = malloc(*nepm);
1946 if(*epm == nil){
1947 freebytes(key);
1948 werrstr("out of memory");
1949 return -1;
1951 memmove(*epm, key->data, *nepm);
1953 freebytes(key);
1955 return 1;
1958 static void
1959 sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
1961 DigestState *s;
1962 uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
1963 char *label;
1965 if(isClient)
1966 label = "CLNT";
1967 else
1968 label = "SRVR";
1970 md5((uchar*)label, 4, nil, &hsmd5);
1971 md5(sec->sec, MasterSecretSize, nil, &hsmd5);
1972 memset(pad, 0x36, 48);
1973 md5(pad, 48, nil, &hsmd5);
1974 md5(nil, 0, h0, &hsmd5);
1975 memset(pad, 0x5C, 48);
1976 s = md5(sec->sec, MasterSecretSize, nil, nil);
1977 s = md5(pad, 48, nil, s);
1978 md5(h0, MD5dlen, finished, s);
1980 sha1((uchar*)label, 4, nil, &hssha1);
1981 sha1(sec->sec, MasterSecretSize, nil, &hssha1);
1982 memset(pad, 0x36, 40);
1983 sha1(pad, 40, nil, &hssha1);
1984 sha1(nil, 0, h1, &hssha1);
1985 memset(pad, 0x5C, 40);
1986 s = sha1(sec->sec, MasterSecretSize, nil, nil);
1987 s = sha1(pad, 40, nil, s);
1988 sha1(h1, SHA1dlen, finished + MD5dlen, s);
1991 /* fill "finished" arg with md5(args)^sha1(args) */
1992 static void
1993 tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
1995 uchar h0[MD5dlen], h1[SHA1dlen];
1996 char *label;
1998 /* get current hash value, but allow further messages to be hashed in */
1999 md5(nil, 0, h0, &hsmd5);
2000 sha1(nil, 0, h1, &hssha1);
2002 if(isClient)
2003 label = "client finished";
2004 else
2005 label = "server finished";
2006 tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
2009 static void
2010 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2012 DigestState *s;
2013 uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2014 int i, n, len;
2016 USED(label);
2017 len = 1;
2018 while(nbuf > 0){
2019 if(len > 26)
2020 return;
2021 for(i = 0; i < len; i++)
2022 tmp[i] = 'A' - 1 + len;
2023 s = sha1(tmp, len, nil, nil);
2024 s = sha1(key, nkey, nil, s);
2025 s = sha1(seed0, nseed0, nil, s);
2026 sha1(seed1, nseed1, sha1dig, s);
2027 s = md5(key, nkey, nil, nil);
2028 md5(sha1dig, SHA1dlen, md5dig, s);
2029 n = MD5dlen;
2030 if(n > nbuf)
2031 n = nbuf;
2032 memmove(buf, md5dig, n);
2033 buf += n;
2034 nbuf -= n;
2035 len++;
2039 static mpint*
2040 bytestomp(Bytes* bytes)
2042 mpint* ans;
2044 ans = betomp(bytes->data, bytes->len, nil);
2045 return ans;
2049 * Convert mpint* to Bytes, putting high order byte first.
2051 static Bytes*
2052 mptobytes(mpint* big)
2054 int n, m;
2055 uchar *a;
2056 Bytes* ans;
2058 n = (mpsignif(big)+7)/8;
2059 m = mptobe(big, nil, n, &a);
2060 ans = makebytes(a, m);
2061 return ans;
2064 /* Do RSA computation on block according to key, and pad */
2065 /* result on left with zeros to make it modlen long. */
2066 static Bytes*
2067 rsacomp(Bytes* block, RSApub* key, int modlen)
2069 mpint *x, *y;
2070 Bytes *a, *ybytes;
2071 int ylen;
2073 x = bytestomp(block);
2074 y = rsaencrypt(key, x, nil);
2075 mpfree(x);
2076 ybytes = mptobytes(y);
2077 ylen = ybytes->len;
2079 if(ylen < modlen) {
2080 a = newbytes(modlen);
2081 memset(a->data, 0, modlen-ylen);
2082 memmove(a->data+modlen-ylen, ybytes->data, ylen);
2083 freebytes(ybytes);
2084 ybytes = a;
2086 else if(ylen > modlen) {
2087 /* assume it has leading zeros (mod should make it so) */
2088 a = newbytes(modlen);
2089 memmove(a->data, ybytes->data, modlen);
2090 freebytes(ybytes);
2091 ybytes = a;
2093 mpfree(y);
2094 return ybytes;
2097 /* encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1 */
2098 static Bytes*
2099 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2101 Bytes *pad, *eb, *ans;
2102 int i, dlen, padlen, modlen;
2104 modlen = (mpsignif(key->n)+7)/8;
2105 dlen = data->len;
2106 if(modlen < 12 || dlen > modlen - 11)
2107 return nil;
2108 padlen = modlen - 3 - dlen;
2109 pad = newbytes(padlen);
2110 genrandom(pad->data, padlen);
2111 for(i = 0; i < padlen; i++) {
2112 if(blocktype == 0)
2113 pad->data[i] = 0;
2114 else if(blocktype == 1)
2115 pad->data[i] = 255;
2116 else if(pad->data[i] == 0)
2117 pad->data[i] = 1;
2119 eb = newbytes(modlen);
2120 eb->data[0] = 0;
2121 eb->data[1] = blocktype;
2122 memmove(eb->data+2, pad->data, padlen);
2123 eb->data[padlen+2] = 0;
2124 memmove(eb->data+padlen+3, data->data, dlen);
2125 ans = rsacomp(eb, key, modlen);
2126 freebytes(eb);
2127 freebytes(pad);
2128 return ans;
2131 /* decrypt data according to PKCS#1, with given key. */
2132 /* expect a block type of 2. */
2133 static Bytes*
2134 pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
2136 Bytes *eb, *ans = nil;
2137 int i, modlen;
2138 mpint *x, *y;
2140 modlen = (mpsignif(sec->rsapub->n)+7)/8;
2141 if(nepm != modlen)
2142 return nil;
2143 x = betomp(epm, nepm, nil);
2144 y = factotum_rsa_decrypt(sec->rpc, x);
2145 if(y == nil)
2146 return nil;
2147 eb = mptobytes(y);
2148 if(eb->len < modlen){ /* pad on left with zeros */
2149 ans = newbytes(modlen);
2150 memset(ans->data, 0, modlen-eb->len);
2151 memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2152 freebytes(eb);
2153 eb = ans;
2155 if(eb->data[0] == 0 && eb->data[1] == 2) {
2156 for(i = 2; i < modlen; i++)
2157 if(eb->data[i] == 0)
2158 break;
2159 if(i < modlen - 1)
2160 ans = makebytes(eb->data+i+1, modlen-(i+1));
2162 freebytes(eb);
2163 return ans;
2167 /*================= general utility functions ======================== */
2169 static void *
2170 emalloc(int n)
2172 void *p;
2173 if(n==0)
2174 n=1;
2175 p = malloc(n);
2176 if(p == nil){
2177 exits("out of memory");
2179 memset(p, 0, n);
2180 return p;
2183 static void *
2184 erealloc(void *ReallocP, int ReallocN)
2186 if(ReallocN == 0)
2187 ReallocN = 1;
2188 if(!ReallocP)
2189 ReallocP = emalloc(ReallocN);
2190 else if(!(ReallocP = realloc(ReallocP, ReallocN))){
2191 exits("out of memory");
2193 return(ReallocP);
2196 static void
2197 put32(uchar *p, u32int x)
2199 p[0] = x>>24;
2200 p[1] = x>>16;
2201 p[2] = x>>8;
2202 p[3] = x;
2205 static void
2206 put24(uchar *p, int x)
2208 p[0] = x>>16;
2209 p[1] = x>>8;
2210 p[2] = x;
2213 static void
2214 put16(uchar *p, int x)
2216 p[0] = x>>8;
2217 p[1] = x;
2221 static u32int
2222 get32(uchar *p)
2224 return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2228 static int
2229 get24(uchar *p)
2231 return (p[0]<<16)|(p[1]<<8)|p[2];
2234 static int
2235 get16(uchar *p)
2237 return (p[0]<<8)|p[1];
/* OFFSET(x, s): byte offset of member x within type s, like ANSI offsetof() */
#define OFFSET(x, s) ((intptr)&((s*)0)->x)
2244 * malloc and return a new Bytes structure capable of
2245 * holding len bytes. (len >= 0)
2246 * Used to use crypt_malloc, which aborts if malloc fails.
2248 static Bytes*
2249 newbytes(int len)
2251 Bytes* ans;
2253 ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len);
2254 ans->len = len;
2255 return ans;
2259 * newbytes(len), with data initialized from buf
2261 static Bytes*
2262 makebytes(uchar* buf, int len)
2264 Bytes* ans;
2266 ans = newbytes(len);
2267 memmove(ans->data, buf, len);
2268 return ans;
2271 static void
2272 freebytes(Bytes* b)
2274 if(b != nil)
2275 free(b);
2278 /* len is number of ints */
2279 static Ints*
2280 newints(int len)
2282 Ints* ans;
2284 ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int));
2285 ans->len = len;
2286 return ans;
2290 static Ints*
2291 makeints(int* buf, int len)
2293 Ints* ans;
2295 ans = newints(len);
2296 if(len > 0)
2297 memmove(ans->data, buf, len*sizeof(int));
2298 return ans;
2302 static void
2303 freeints(Ints* b)
2305 if(b != nil)
2306 free(b);