Blob


1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
// The main groups of functions are:
//		client/server - main handshake protocol definition
//		message functions - formatting handshake messages
//		cipher choices - catalog of digest and encrypt algorithms
//		security functions - PKCS#1, sslHMAC, session keygen
//		general utility functions - malloc, serialization
// The handshake protocol builds on the TLS/SSL3 record layer protocol,
// which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
17 enum {
18 TLSFinishedLen = 12,
19 SSL3FinishedLen = MD5dlen+SHA1dlen,
20 MaxKeyData = 104, // amount of secret we may need
21 MaxChunk = 1<<14,
22 RandomSize = 32,
23 SidSize = 32,
24 MasterSecretSize = 48,
25 AQueue = 0,
26 AFlush = 1,
27 };
29 typedef struct TlsSec TlsSec;
31 typedef struct Bytes{
32 int len;
33 uchar data[1]; // [len]
34 } Bytes;
36 typedef struct Ints{
37 int len;
38 int data[1]; // [len]
39 } Ints;
41 typedef struct Algs{
42 char *enc;
43 char *digest;
44 int nsecret;
45 int tlsid;
46 int ok;
47 } Algs;
49 typedef struct Finished{
50 uchar verify[SSL3FinishedLen];
51 int n;
52 } Finished;
54 typedef struct TlsConnection{
55 TlsSec *sec; // security management goo
56 int hand, ctl; // record layer file descriptors
57 int erred; // set when tlsError called
58 int (*trace)(char*fmt, ...); // for debugging
59 int version; // protocol we are speaking
60 int verset; // version has been set
61 int ver2hi; // server got a version 2 hello
62 int isClient; // is this the client or server?
63 Bytes *sid; // SessionID
64 Bytes *cert; // only last - no chain
66 Lock statelk;
67 int state; // must be set using setstate
69 // input buffer for handshake messages
70 uchar buf[MaxChunk+2048];
71 uchar *rp, *ep;
73 uchar crandom[RandomSize]; // client random
74 uchar srandom[RandomSize]; // server random
75 int clientVersion; // version in ClientHello
76 char *digest; // name of digest algorithm to use
77 char *enc; // name of encryption algorithm to use
78 int nsecret; // amount of secret data to init keys
80 // for finished messages
81 MD5state hsmd5; // handshake hash
82 SHAstate hssha1; // handshake hash
83 Finished finished;
84 } TlsConnection;
86 typedef struct Msg{
87 int tag;
88 union {
89 struct {
90 int version;
91 uchar random[RandomSize];
92 Bytes* sid;
93 Ints* ciphers;
94 Bytes* compressors;
95 } clientHello;
96 struct {
97 int version;
98 uchar random[RandomSize];
99 Bytes* sid;
100 int cipher;
101 int compressor;
102 } serverHello;
103 struct {
104 int ncert;
105 Bytes **certs;
106 } certificate;
107 struct {
108 Bytes *types;
109 int nca;
110 Bytes **cas;
111 } certificateRequest;
112 struct {
113 Bytes *key;
114 } clientKeyExchange;
115 Finished finished;
116 } u;
117 } Msg;
119 struct TlsSec{
120 char *server; // name of remote; nil for server
121 int ok; // <0 killed; ==0 in progress; >0 reusable
122 RSApub *rsapub;
123 AuthRpc *rpc; // factotum for rsa private key
124 uchar sec[MasterSecretSize]; // master secret
125 uchar crandom[RandomSize]; // client random
126 uchar srandom[RandomSize]; // server random
127 int clientVers; // version in ClientHello
128 int vers; // final version
129 // byte generation and handshake checksum
130 void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
131 void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
132 int nfin;
133 };
/* protocol version numbers as carried in hello messages */
enum {
	SSL3Version	= 0x0300,
	TLSVersion	= 0x0301,

	ProtocolVersion	= TLSVersion,	/* maximum version we speak */
	MinProtoVersion	= SSL3Version,	/* limits on version we accept */
	MaxProtoVersion	= 0x03ff,
};
/* handshake message types (the first byte of each handshake message) */
enum {
	HHelloRequest		= 0,
	HClientHello		= 1,
	HServerHello		= 2,
	HSSL2ClientHello	= 9,	/* local convention; see devtls.c */
	HCertificate		= 11,
	HServerKeyExchange	= 12,
	HCertificateRequest	= 13,
	HServerHelloDone	= 14,
	HCertificateVerify	= 15,
	HClientKeyExchange	= 16,
	HFinished		= 20,
	HMax			/* one past the largest type */
};
/* alert descriptions passed to tlsError */
enum {
	/* connection and record problems */
	ECloseNotify		= 0,
	EUnexpectedMessage	= 10,
	EBadRecordMac		= 20,
	EDecryptionFailed	= 21,
	ERecordOverflow		= 22,
	EDecompressionFailure	= 30,
	EHandshakeFailure	= 40,

	/* certificate problems */
	ENoCertificate		= 41,
	EBadCertificate		= 42,
	EUnsupportedCertificate	= 43,
	ECertificateRevoked	= 44,
	ECertificateExpired	= 45,
	ECertificateUnknown	= 46,

	/* parameter, policy and internal problems */
	EIllegalParameter	= 47,
	EUnknownCa		= 48,
	EAccessDenied		= 49,
	EDecodeError		= 50,
	EDecryptError		= 51,
	EExportRestriction	= 60,
	EProtocolVersion	= 70,
	EInsufficientSecurity	= 71,
	EInternalError		= 80,
	EUserCanceled		= 90,
	ENoRenegotiation	= 100,
	EMax			= 256
};
/* cipher suite numbers; cipherAlgs lists the ones implemented here */
enum {
	TLS_NULL_WITH_NULL_NULL			= 0x0000,
	TLS_RSA_WITH_NULL_MD5			= 0x0001,
	TLS_RSA_WITH_NULL_SHA			= 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5		= 0x0003,
	TLS_RSA_WITH_RC4_128_MD5		= 0x0004,
	TLS_RSA_WITH_RC4_128_SHA		= 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5	= 0x0006,
	TLS_RSA_WITH_IDEA_CBC_SHA		= 0x0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0008,
	TLS_RSA_WITH_DES_CBC_SHA		= 0x0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0x000A,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x000B,
	TLS_DH_DSS_WITH_DES_CBC_SHA		= 0x000C,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA	= 0x000D,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x000E,
	TLS_DH_RSA_WITH_DES_CBC_SHA		= 0x000F,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA		= 0x0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA	= 0x0013,	// ZZZ must be implemented for tls1.0 compliance
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA		= 0x0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5	= 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5		= 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA	= 0x0019,
	TLS_DH_anon_WITH_DES_CBC_SHA		= 0x001A,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA	= 0x001B,

	/* aes, aka rijndael with 128 bit blocks */
	TLS_RSA_WITH_AES_128_CBC_SHA		= 0x002F,
	TLS_DH_DSS_WITH_AES_128_CBC_SHA		= 0x0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA		= 0x0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA	= 0x0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0x0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA	= 0x0034,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0x0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA		= 0x0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA		= 0x0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0x0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0x0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0x003A,
	CipherMax				/* one past the largest suite listed */
};
/* compression methods; only the null method exists */
enum {
	CompressionNull	= 0,
	CompressionMax		/* one past the last method */
};
241 static Algs cipherAlgs[] = {
242 {"rc4_128", "md5", 2 * (16 + MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
243 {"rc4_128", "sha1", 2 * (16 + SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
244 {"3des_ede_cbc","sha1",2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
245 };
247 static uchar compressors[] = {
248 CompressionNull,
249 };
251 static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...));
252 static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...));
254 static void msgClear(Msg *m);
255 static char* msgPrint(char *buf, int n, Msg *m);
256 static int msgRecv(TlsConnection *c, Msg *m);
257 static int msgSend(TlsConnection *c, Msg *m, int act);
258 static void tlsError(TlsConnection *c, int err, char *msg, ...);
259 /* #pragma varargck argpos tlsError 3*/
260 static int setVersion(TlsConnection *c, int version);
261 static int finishedMatch(TlsConnection *c, Finished *f);
262 static void tlsConnectionFree(TlsConnection *c);
264 static int setAlgs(TlsConnection *c, int a);
265 static int okCipher(Ints *cv);
266 static int okCompression(Bytes *cv);
267 static int initCiphers(void);
268 static Ints* makeciphers(void);
270 static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
271 static int tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
272 static TlsSec* tlsSecInitc(int cvers, uchar *crandom);
273 static int tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
274 static int tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
275 static void tlsSecOk(TlsSec *sec);
276 static void tlsSecKill(TlsSec *sec);
277 static void tlsSecClose(TlsSec *sec);
278 static void setMasterSecret(TlsSec *sec, Bytes *pm);
279 static void serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
280 static void setSecrets(TlsSec *sec, uchar *kd, int nkd);
281 static int clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
282 static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
283 static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
284 static void tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
285 static void sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
286 static void sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
287 uchar *seed0, int nseed0, uchar *seed1, int nseed1);
288 static int setVers(TlsSec *sec, int version);
290 static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
291 static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
292 static void factotum_rsa_close(AuthRpc*rpc);
294 static void* emalloc(int);
295 static void* erealloc(void*, int);
296 static void put32(uchar *p, u32int);
297 static void put24(uchar *p, int);
298 static void put16(uchar *p, int);
299 static u32int get32(uchar *p);
300 static int get24(uchar *p);
301 static int get16(uchar *p);
302 static Bytes* newbytes(int len);
303 static Bytes* makebytes(uchar* buf, int len);
304 static void freebytes(Bytes* b);
305 static Ints* newints(int len);
306 static Ints* makeints(int* buf, int len);
307 static void freeints(Ints* b);
309 //================= client/server ========================
311 // push TLS onto fd, returning new (application) file descriptor
312 // or -1 if error.
313 int
314 tlsServer(int fd, TLSconn *conn)
316 char buf[8];
317 char dname[64];
318 int n, data, ctl, hand;
319 TlsConnection *tls;
321 if(conn == nil)
322 return -1;
323 ctl = open("#a/tls/clone", ORDWR);
324 if(ctl < 0)
325 return -1;
326 n = read(ctl, buf, sizeof(buf)-1);
327 if(n < 0){
328 close(ctl);
329 return -1;
331 buf[n] = 0;
332 sprint(conn->dir, "#a/tls/%s", buf);
333 sprint(dname, "#a/tls/%s/hand", buf);
334 hand = open(dname, ORDWR);
335 if(hand < 0){
336 close(ctl);
337 return -1;
339 fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
340 tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace);
341 sprint(dname, "#a/tls/%s/data", buf);
342 data = open(dname, ORDWR);
343 close(fd);
344 close(hand);
345 close(ctl);
346 if(data < 0){
347 return -1;
349 if(tls == nil){
350 close(data);
351 return -1;
353 if(conn->cert)
354 free(conn->cert);
355 conn->cert = 0; // client certificates are not yet implemented
356 conn->certlen = 0;
357 conn->sessionIDlen = tls->sid->len;
358 conn->sessionID = emalloc(conn->sessionIDlen);
359 memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
360 tlsConnectionFree(tls);
361 return data;
364 // push TLS onto fd, returning new (application) file descriptor
365 // or -1 if error.
366 int
367 tlsClient(int fd, TLSconn *conn)
369 char buf[8];
370 char dname[64];
371 int n, data, ctl, hand;
372 TlsConnection *tls;
374 if(!conn)
375 return -1;
376 ctl = open("#a/tls/clone", ORDWR);
377 if(ctl < 0)
378 return -1;
379 n = read(ctl, buf, sizeof(buf)-1);
380 if(n < 0){
381 close(ctl);
382 return -1;
384 buf[n] = 0;
385 sprint(conn->dir, "#a/tls/%s", buf);
386 sprint(dname, "#a/tls/%s/hand", buf);
387 hand = open(dname, ORDWR);
388 if(hand < 0){
389 close(ctl);
390 return -1;
392 sprint(dname, "#a/tls/%s/data", buf);
393 data = open(dname, ORDWR);
394 if(data < 0)
395 return -1;
396 fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
397 tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace);
398 close(fd);
399 close(hand);
400 close(ctl);
401 if(tls == nil){
402 close(data);
403 return -1;
405 conn->certlen = tls->cert->len;
406 conn->cert = emalloc(conn->certlen);
407 memcpy(conn->cert, tls->cert->data, conn->certlen);
408 conn->sessionIDlen = tls->sid->len;
409 conn->sessionID = emalloc(conn->sessionIDlen);
410 memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
411 tlsConnectionFree(tls);
412 return data;
415 static TlsConnection *
416 tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...))
418 TlsConnection *c;
419 Msg m;
420 Bytes *csid;
421 uchar sid[SidSize], kd[MaxKeyData];
422 char *secrets;
423 int cipher, compressor, nsid, rv;
425 if(trace)
426 trace("tlsServer2\n");
427 if(!initCiphers())
428 return nil;
429 c = emalloc(sizeof(TlsConnection));
430 c->ctl = ctl;
431 c->hand = hand;
432 c->trace = trace;
433 c->version = ProtocolVersion;
435 memset(&m, 0, sizeof(m));
436 if(!msgRecv(c, &m)){
437 if(trace)
438 trace("initial msgRecv failed\n");
439 goto Err;
441 if(m.tag != HClientHello) {
442 tlsError(c, EUnexpectedMessage, "expected a client hello");
443 goto Err;
445 c->clientVersion = m.u.clientHello.version;
446 if(trace)
447 trace("ClientHello version %x\n", c->clientVersion);
448 if(setVersion(c, m.u.clientHello.version) < 0) {
449 tlsError(c, EIllegalParameter, "incompatible version");
450 goto Err;
453 memmove(c->crandom, m.u.clientHello.random, RandomSize);
454 cipher = okCipher(m.u.clientHello.ciphers);
455 if(cipher < 0) {
456 // reply with EInsufficientSecurity if we know that's the case
457 if(cipher == -2)
458 tlsError(c, EInsufficientSecurity, "cipher suites too weak");
459 else
460 tlsError(c, EHandshakeFailure, "no matching cipher suite");
461 goto Err;
463 if(!setAlgs(c, cipher)){
464 tlsError(c, EHandshakeFailure, "no matching cipher suite");
465 goto Err;
467 compressor = okCompression(m.u.clientHello.compressors);
468 if(compressor < 0) {
469 tlsError(c, EHandshakeFailure, "no matching compressor");
470 goto Err;
473 csid = m.u.clientHello.sid;
474 if(trace)
475 trace(" cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
476 c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
477 if(c->sec == nil){
478 tlsError(c, EHandshakeFailure, "can't initialize security: %r");
479 goto Err;
481 c->sec->rpc = factotum_rsa_open(cert, ncert);
482 if(c->sec->rpc == nil){
483 tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
484 goto Err;
486 c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
487 msgClear(&m);
489 m.tag = HServerHello;
490 m.u.serverHello.version = c->version;
491 memmove(m.u.serverHello.random, c->srandom, RandomSize);
492 m.u.serverHello.cipher = cipher;
493 m.u.serverHello.compressor = compressor;
494 c->sid = makebytes(sid, nsid);
495 m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
496 if(!msgSend(c, &m, AQueue))
497 goto Err;
498 msgClear(&m);
500 m.tag = HCertificate;
501 m.u.certificate.ncert = 1;
502 m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes));
503 m.u.certificate.certs[0] = makebytes(cert, ncert);
504 if(!msgSend(c, &m, AQueue))
505 goto Err;
506 msgClear(&m);
508 m.tag = HServerHelloDone;
509 if(!msgSend(c, &m, AFlush))
510 goto Err;
511 msgClear(&m);
513 if(!msgRecv(c, &m))
514 goto Err;
515 if(m.tag != HClientKeyExchange) {
516 tlsError(c, EUnexpectedMessage, "expected a client key exchange");
517 goto Err;
519 if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
520 tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
521 goto Err;
523 if(trace)
524 trace("tls secrets\n");
525 secrets = (char*)emalloc(2*c->nsecret);
526 enc64(secrets, 2*c->nsecret, kd, c->nsecret);
527 rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
528 memset(secrets, 0, 2*c->nsecret);
529 free(secrets);
530 memset(kd, 0, c->nsecret);
531 if(rv < 0){
532 tlsError(c, EHandshakeFailure, "can't set keys: %r");
533 goto Err;
535 msgClear(&m);
537 /* no CertificateVerify; skip to Finished */
538 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
539 tlsError(c, EInternalError, "can't set finished: %r");
540 goto Err;
542 if(!msgRecv(c, &m))
543 goto Err;
544 if(m.tag != HFinished) {
545 tlsError(c, EUnexpectedMessage, "expected a finished");
546 goto Err;
548 if(!finishedMatch(c, &m.u.finished)) {
549 tlsError(c, EHandshakeFailure, "finished verification failed");
550 goto Err;
552 msgClear(&m);
554 /* change cipher spec */
555 if(fprint(c->ctl, "changecipher") < 0){
556 tlsError(c, EInternalError, "can't enable cipher: %r");
557 goto Err;
560 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
561 tlsError(c, EInternalError, "can't set finished: %r");
562 goto Err;
564 m.tag = HFinished;
565 m.u.finished = c->finished;
566 if(!msgSend(c, &m, AFlush))
567 goto Err;
568 msgClear(&m);
569 if(trace)
570 trace("tls finished\n");
572 if(fprint(c->ctl, "opened") < 0)
573 goto Err;
574 tlsSecOk(c->sec);
575 return c;
577 Err:
578 msgClear(&m);
579 tlsConnectionFree(c);
580 return 0;
583 static TlsConnection *
584 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...))
586 TlsConnection *c;
587 Msg m;
588 uchar kd[MaxKeyData], *epm;
589 char *secrets;
590 int creq, nepm, rv;
592 if(!initCiphers())
593 return nil;
594 epm = nil;
595 c = emalloc(sizeof(TlsConnection));
596 c->version = ProtocolVersion;
597 c->ctl = ctl;
598 c->hand = hand;
599 c->trace = trace;
600 c->isClient = 1;
601 c->clientVersion = c->version;
603 c->sec = tlsSecInitc(c->clientVersion, c->crandom);
604 if(c->sec == nil)
605 goto Err;
607 /* client hello */
608 memset(&m, 0, sizeof(m));
609 m.tag = HClientHello;
610 m.u.clientHello.version = c->clientVersion;
611 memmove(m.u.clientHello.random, c->crandom, RandomSize);
612 m.u.clientHello.sid = makebytes(csid, ncsid);
613 m.u.clientHello.ciphers = makeciphers();
614 m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
615 if(!msgSend(c, &m, AFlush))
616 goto Err;
617 msgClear(&m);
619 /* server hello */
620 if(!msgRecv(c, &m))
621 goto Err;
622 if(m.tag != HServerHello) {
623 tlsError(c, EUnexpectedMessage, "expected a server hello");
624 goto Err;
626 if(setVersion(c, m.u.serverHello.version) < 0) {
627 tlsError(c, EIllegalParameter, "incompatible version %r");
628 goto Err;
630 memmove(c->srandom, m.u.serverHello.random, RandomSize);
631 c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
632 if(c->sid->len != 0 && c->sid->len != SidSize) {
633 tlsError(c, EIllegalParameter, "invalid server session identifier");
634 goto Err;
636 if(!setAlgs(c, m.u.serverHello.cipher)) {
637 tlsError(c, EIllegalParameter, "invalid cipher suite");
638 goto Err;
640 if(m.u.serverHello.compressor != CompressionNull) {
641 tlsError(c, EIllegalParameter, "invalid compression");
642 goto Err;
644 msgClear(&m);
646 /* certificate */
647 if(!msgRecv(c, &m) || m.tag != HCertificate) {
648 tlsError(c, EUnexpectedMessage, "expected a certificate");
649 goto Err;
651 if(m.u.certificate.ncert < 1) {
652 tlsError(c, EIllegalParameter, "runt certificate");
653 goto Err;
655 c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
656 msgClear(&m);
658 /* server key exchange (optional) */
659 if(!msgRecv(c, &m))
660 goto Err;
661 if(m.tag == HServerKeyExchange) {
662 tlsError(c, EUnexpectedMessage, "got an server key exchange");
663 goto Err;
664 // If implementing this later, watch out for rollback attack
665 // described in Wagner Schneier 1996, section 4.4.
668 /* certificate request (optional) */
669 creq = 0;
670 if(m.tag == HCertificateRequest) {
671 creq = 1;
672 msgClear(&m);
673 if(!msgRecv(c, &m))
674 goto Err;
677 if(m.tag != HServerHelloDone) {
678 tlsError(c, EUnexpectedMessage, "expected a server hello done");
679 goto Err;
681 msgClear(&m);
683 if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
684 c->cert->data, c->cert->len, c->version, &epm, &nepm,
685 kd, c->nsecret) < 0){
686 tlsError(c, EBadCertificate, "invalid x509/rsa certificate");
687 goto Err;
689 secrets = (char*)emalloc(2*c->nsecret);
690 enc64(secrets, 2*c->nsecret, kd, c->nsecret);
691 rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
692 memset(secrets, 0, 2*c->nsecret);
693 free(secrets);
694 memset(kd, 0, c->nsecret);
695 if(rv < 0){
696 tlsError(c, EHandshakeFailure, "can't set keys: %r");
697 goto Err;
700 if(creq) {
701 /* send a zero length certificate */
702 m.tag = HCertificate;
703 if(!msgSend(c, &m, AFlush))
704 goto Err;
705 msgClear(&m);
708 /* client key exchange */
709 m.tag = HClientKeyExchange;
710 m.u.clientKeyExchange.key = makebytes(epm, nepm);
711 free(epm);
712 epm = nil;
713 if(m.u.clientKeyExchange.key == nil) {
714 tlsError(c, EHandshakeFailure, "can't set secret: %r");
715 goto Err;
717 if(!msgSend(c, &m, AFlush))
718 goto Err;
719 msgClear(&m);
721 /* change cipher spec */
722 if(fprint(c->ctl, "changecipher") < 0){
723 tlsError(c, EInternalError, "can't enable cipher: %r");
724 goto Err;
727 // Cipherchange must occur immediately before Finished to avoid
728 // potential hole; see section 4.3 of Wagner Schneier 1996.
729 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
730 tlsError(c, EInternalError, "can't set finished 1: %r");
731 goto Err;
733 m.tag = HFinished;
734 m.u.finished = c->finished;
736 if(!msgSend(c, &m, AFlush)) {
737 fprint(2, "tlsClient nepm=%d\n", nepm);
738 tlsError(c, EInternalError, "can't flush after client Finished: %r");
739 goto Err;
741 msgClear(&m);
743 if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
744 fprint(2, "tlsClient nepm=%d\n", nepm);
745 tlsError(c, EInternalError, "can't set finished 0: %r");
746 goto Err;
748 if(!msgRecv(c, &m)) {
749 fprint(2, "tlsClient nepm=%d\n", nepm);
750 tlsError(c, EInternalError, "can't read server Finished: %r");
751 goto Err;
753 if(m.tag != HFinished) {
754 fprint(2, "tlsClient nepm=%d\n", nepm);
755 tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
756 goto Err;
759 if(!finishedMatch(c, &m.u.finished)) {
760 tlsError(c, EHandshakeFailure, "finished verification failed");
761 goto Err;
763 msgClear(&m);
765 if(fprint(c->ctl, "opened") < 0){
766 if(trace)
767 trace("unable to do final open: %r\n");
768 goto Err;
770 tlsSecOk(c->sec);
771 return c;
773 Err:
774 free(epm);
775 msgClear(&m);
776 tlsConnectionFree(c);
777 return 0;
781 //================= message functions ========================
783 static uchar sendbuf[9000], *sendp;
785 static int
786 msgSend(TlsConnection *c, Msg *m, int act)
788 uchar *p; // sendp = start of new message; p = write pointer
789 int nn, n, i;
791 if(sendp == nil)
792 sendp = sendbuf;
793 p = sendp;
794 if(c->trace)
795 c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m));
797 p[0] = m->tag; // header - fill in size later
798 p += 4;
800 switch(m->tag) {
801 default:
802 tlsError(c, EInternalError, "can't encode a %d", m->tag);
803 goto Err;
804 case HClientHello:
805 // version
806 put16(p, m->u.clientHello.version);
807 p += 2;
809 // random
810 memmove(p, m->u.clientHello.random, RandomSize);
811 p += RandomSize;
813 // sid
814 n = m->u.clientHello.sid->len;
815 assert(n < 256);
816 p[0] = n;
817 memmove(p+1, m->u.clientHello.sid->data, n);
818 p += n+1;
820 n = m->u.clientHello.ciphers->len;
821 assert(n > 0 && n < 200);
822 put16(p, n*2);
823 p += 2;
824 for(i=0; i<n; i++) {
825 put16(p, m->u.clientHello.ciphers->data[i]);
826 p += 2;
829 n = m->u.clientHello.compressors->len;
830 assert(n > 0);
831 p[0] = n;
832 memmove(p+1, m->u.clientHello.compressors->data, n);
833 p += n+1;
834 break;
835 case HServerHello:
836 put16(p, m->u.serverHello.version);
837 p += 2;
839 // random
840 memmove(p, m->u.serverHello.random, RandomSize);
841 p += RandomSize;
843 // sid
844 n = m->u.serverHello.sid->len;
845 assert(n < 256);
846 p[0] = n;
847 memmove(p+1, m->u.serverHello.sid->data, n);
848 p += n+1;
850 put16(p, m->u.serverHello.cipher);
851 p += 2;
852 p[0] = m->u.serverHello.compressor;
853 p += 1;
854 break;
855 case HServerHelloDone:
856 break;
857 case HCertificate:
858 nn = 0;
859 for(i = 0; i < m->u.certificate.ncert; i++)
860 nn += 3 + m->u.certificate.certs[i]->len;
861 if(p + 3 + nn - sendbuf > sizeof(sendbuf)) {
862 tlsError(c, EInternalError, "output buffer too small for certificate");
863 goto Err;
865 put24(p, nn);
866 p += 3;
867 for(i = 0; i < m->u.certificate.ncert; i++){
868 put24(p, m->u.certificate.certs[i]->len);
869 p += 3;
870 memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
871 p += m->u.certificate.certs[i]->len;
873 break;
874 case HClientKeyExchange:
875 n = m->u.clientKeyExchange.key->len;
876 if(c->version != SSL3Version){
877 put16(p, n);
878 p += 2;
880 memmove(p, m->u.clientKeyExchange.key->data, n);
881 p += n;
882 break;
883 case HFinished:
884 memmove(p, m->u.finished.verify, m->u.finished.n);
885 p += m->u.finished.n;
886 break;
889 // go back and fill in size
890 n = p-sendp;
891 assert(p <= sendbuf+sizeof(sendbuf));
892 put24(sendp+1, n-4);
894 // remember hash of Handshake messages
895 if(m->tag != HHelloRequest) {
896 md5(sendp, n, 0, &c->hsmd5);
897 sha1(sendp, n, 0, &c->hssha1);
900 sendp = p;
901 if(act == AFlush){
902 sendp = sendbuf;
903 if(write(c->hand, sendbuf, p-sendbuf) < 0){
904 fprint(2, "write error: %r\n");
905 goto Err;
908 msgClear(m);
909 return 1;
910 Err:
911 msgClear(m);
912 return 0;
915 static uchar*
916 tlsReadN(TlsConnection *c, int n)
918 uchar *p;
919 int nn, nr;
921 nn = c->ep - c->rp;
922 if(nn < n){
923 if(c->rp != c->buf){
924 memmove(c->buf, c->rp, nn);
925 c->rp = c->buf;
926 c->ep = &c->buf[nn];
928 for(; nn < n; nn += nr) {
929 nr = read(c->hand, &c->rp[nn], n - nn);
930 if(nr <= 0)
931 return nil;
932 c->ep += nr;
935 p = c->rp;
936 c->rp += n;
937 return p;
940 static int
941 msgRecv(TlsConnection *c, Msg *m)
943 uchar *p;
944 int type, n, nn, i, nsid, nrandom, nciph;
946 for(;;) {
947 p = tlsReadN(c, 4);
948 if(p == nil)
949 return 0;
950 type = p[0];
951 n = get24(p+1);
953 if(type != HHelloRequest)
954 break;
955 if(n != 0) {
956 tlsError(c, EDecodeError, "invalid hello request during handshake");
957 return 0;
961 if(n > sizeof(c->buf)) {
962 tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf));
963 return 0;
966 if(type == HSSL2ClientHello){
967 /* Cope with an SSL3 ClientHello expressed in SSL2 record format.
968 This is sent by some clients that we must interoperate
969 with, such as Java's JSSE and Microsoft's Internet Explorer. */
970 p = tlsReadN(c, n);
971 if(p == nil)
972 return 0;
973 md5(p, n, 0, &c->hsmd5);
974 sha1(p, n, 0, &c->hssha1);
975 m->tag = HClientHello;
976 if(n < 22)
977 goto Short;
978 m->u.clientHello.version = get16(p+1);
979 p += 3;
980 n -= 3;
981 nn = get16(p); /* cipher_spec_len */
982 nsid = get16(p + 2);
983 nrandom = get16(p + 4);
984 p += 6;
985 n -= 6;
986 if(nsid != 0 /* no sid's, since shouldn't restart using ssl2 header */
987 || nrandom < 16 || nn % 3)
988 goto Err;
989 if(c->trace && (n - nrandom != nn))
990 c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
991 /* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
992 nciph = 0;
993 for(i = 0; i < nn; i += 3)
994 if(p[i] == 0)
995 nciph++;
996 m->u.clientHello.ciphers = newints(nciph);
997 nciph = 0;
998 for(i = 0; i < nn; i += 3)
999 if(p[i] == 0)
1000 m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1001 p += nn;
1002 m->u.clientHello.sid = makebytes(nil, 0);
1003 if(nrandom > RandomSize)
1004 nrandom = RandomSize;
1005 memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1006 memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1007 m->u.clientHello.compressors = newbytes(1);
1008 m->u.clientHello.compressors->data[0] = CompressionNull;
1009 goto Ok;
1012 md5(p, 4, 0, &c->hsmd5);
1013 sha1(p, 4, 0, &c->hssha1);
1015 p = tlsReadN(c, n);
1016 if(p == nil)
1017 return 0;
1019 md5(p, n, 0, &c->hsmd5);
1020 sha1(p, n, 0, &c->hssha1);
1022 m->tag = type;
1024 switch(type) {
1025 default:
1026 tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1027 goto Err;
1028 case HClientHello:
1029 if(n < 2)
1030 goto Short;
1031 m->u.clientHello.version = get16(p);
1032 p += 2;
1033 n -= 2;
1035 if(n < RandomSize)
1036 goto Short;
1037 memmove(m->u.clientHello.random, p, RandomSize);
1038 p += RandomSize;
1039 n -= RandomSize;
1040 if(n < 1 || n < p[0]+1)
1041 goto Short;
1042 m->u.clientHello.sid = makebytes(p+1, p[0]);
1043 p += m->u.clientHello.sid->len+1;
1044 n -= m->u.clientHello.sid->len+1;
1046 if(n < 2)
1047 goto Short;
1048 nn = get16(p);
1049 p += 2;
1050 n -= 2;
1052 if((nn & 1) || n < nn || nn < 2)
1053 goto Short;
1054 m->u.clientHello.ciphers = newints(nn >> 1);
1055 for(i = 0; i < nn; i += 2)
1056 m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1057 p += nn;
1058 n -= nn;
1060 if(n < 1 || n < p[0]+1 || p[0] == 0)
1061 goto Short;
1062 nn = p[0];
1063 m->u.clientHello.compressors = newbytes(nn);
1064 memmove(m->u.clientHello.compressors->data, p+1, nn);
1065 n -= nn + 1;
1066 break;
1067 case HServerHello:
1068 if(n < 2)
1069 goto Short;
1070 m->u.serverHello.version = get16(p);
1071 p += 2;
1072 n -= 2;
1074 if(n < RandomSize)
1075 goto Short;
1076 memmove(m->u.serverHello.random, p, RandomSize);
1077 p += RandomSize;
1078 n -= RandomSize;
1080 if(n < 1 || n < p[0]+1)
1081 goto Short;
1082 m->u.serverHello.sid = makebytes(p+1, p[0]);
1083 p += m->u.serverHello.sid->len+1;
1084 n -= m->u.serverHello.sid->len+1;
1086 if(n < 3)
1087 goto Short;
1088 m->u.serverHello.cipher = get16(p);
1089 m->u.serverHello.compressor = p[2];
1090 n -= 3;
1091 break;
1092 case HCertificate:
1093 if(n < 3)
1094 goto Short;
1095 nn = get24(p);
1096 p += 3;
1097 n -= 3;
1098 if(n != nn)
1099 goto Short;
1100 /* certs */
1101 i = 0;
1102 while(n > 0) {
1103 if(n < 3)
1104 goto Short;
1105 nn = get24(p);
1106 p += 3;
1107 n -= 3;
1108 if(nn > n)
1109 goto Short;
1110 m->u.certificate.ncert = i+1;
1111 m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes));
1112 m->u.certificate.certs[i] = makebytes(p, nn);
1113 p += nn;
1114 n -= nn;
1115 i++;
1117 break;
1118 case HCertificateRequest:
1119 if(n < 2)
1120 goto Short;
1121 nn = get16(p);
1122 p += 2;
1123 n -= 2;
1124 if(nn < 1 || nn > n)
1125 goto Short;
1126 m->u.certificateRequest.types = makebytes(p, nn);
1127 nn = get24(p);
1128 p += 3;
1129 n -= 3;
1130 if(nn == 0 || n != nn)
1131 goto Short;
1132 /* cas */
1133 i = 0;
1134 while(n > 0) {
1135 if(n < 2)
1136 goto Short;
1137 nn = get16(p);
1138 p += 2;
1139 n -= 2;
1140 if(nn < 1 || nn > n)
1141 goto Short;
1142 m->u.certificateRequest.nca = i+1;
1143 m->u.certificateRequest.cas = erealloc(m->u.certificateRequest.cas, (i+1)*sizeof(Bytes));
1144 m->u.certificateRequest.cas[i] = makebytes(p, nn);
1145 p += nn;
1146 n -= nn;
1147 i++;
1149 break;
1150 case HServerHelloDone:
1151 break;
1152 case HClientKeyExchange:
1154 * this message depends upon the encryption selected
1155 * assume rsa.
1157 if(c->version == SSL3Version)
1158 nn = n;
1159 else{
1160 if(n < 2)
1161 goto Short;
1162 nn = get16(p);
1163 p += 2;
1164 n -= 2;
1166 if(n < nn)
1167 goto Short;
1168 m->u.clientKeyExchange.key = makebytes(p, nn);
1169 n -= nn;
1170 break;
1171 case HFinished:
1172 m->u.finished.n = c->finished.n;
1173 if(n < m->u.finished.n)
1174 goto Short;
1175 memmove(m->u.finished.verify, p, m->u.finished.n);
1176 n -= m->u.finished.n;
1177 break;
1180 if(type != HClientHello && n != 0)
1181 goto Short;
1182 Ok:
1183 if(c->trace){
1184 char buf[8000];
1185 c->trace("recv %s", msgPrint(buf, sizeof buf, m));
1187 return 1;
1188 Short:
1189 tlsError(c, EDecodeError, "handshake message has invalid length");
1190 Err:
1191 msgClear(m);
1192 return 0;
1195 static void
1196 msgClear(Msg *m)
1198 int i;
1200 switch(m->tag) {
1201 default:
1202 sysfatal("msgClear: unknown message type: %d\n", m->tag);
1203 case HHelloRequest:
1204 break;
1205 case HClientHello:
1206 freebytes(m->u.clientHello.sid);
1207 freeints(m->u.clientHello.ciphers);
1208 freebytes(m->u.clientHello.compressors);
1209 break;
1210 case HServerHello:
1211 freebytes(m->u.clientHello.sid);
1212 break;
1213 case HCertificate:
1214 for(i=0; i<m->u.certificate.ncert; i++)
1215 freebytes(m->u.certificate.certs[i]);
1216 free(m->u.certificate.certs);
1217 break;
1218 case HCertificateRequest:
1219 freebytes(m->u.certificateRequest.types);
1220 for(i=0; i<m->u.certificateRequest.nca; i++)
1221 freebytes(m->u.certificateRequest.cas[i]);
1222 free(m->u.certificateRequest.cas);
1223 break;
1224 case HServerHelloDone:
1225 break;
1226 case HClientKeyExchange:
1227 freebytes(m->u.clientKeyExchange.key);
1228 break;
1229 case HFinished:
1230 break;
1232 memset(m, 0, sizeof(Msg));
1235 static char *
1236 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1238 int i;
1240 if(s0)
1241 bs = seprint(bs, be, "%s", s0);
1242 bs = seprint(bs, be, "[");
1243 if(b == nil)
1244 bs = seprint(bs, be, "nil");
1245 else
1246 for(i=0; i<b->len; i++)
1247 bs = seprint(bs, be, "%.2x ", b->data[i]);
1248 bs = seprint(bs, be, "]");
1249 if(s1)
1250 bs = seprint(bs, be, "%s", s1);
1251 return bs;
1254 static char *
1255 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1257 int i;
1259 if(s0)
1260 bs = seprint(bs, be, "%s", s0);
1261 bs = seprint(bs, be, "[");
1262 if(b == nil)
1263 bs = seprint(bs, be, "nil");
1264 else
1265 for(i=0; i<b->len; i++)
1266 bs = seprint(bs, be, "%x ", b->data[i]);
1267 bs = seprint(bs, be, "]");
1268 if(s1)
1269 bs = seprint(bs, be, "%s", s1);
1270 return bs;
1273 static char*
1274 msgPrint(char *buf, int n, Msg *m)
1276 int i;
1277 char *bs = buf, *be = buf+n;
1279 switch(m->tag) {
1280 default:
1281 bs = seprint(bs, be, "unknown %d\n", m->tag);
1282 break;
1283 case HClientHello:
1284 bs = seprint(bs, be, "ClientHello\n");
1285 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1286 bs = seprint(bs, be, "\trandom: ");
1287 for(i=0; i<RandomSize; i++)
1288 bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1289 bs = seprint(bs, be, "\n");
1290 bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1291 bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1292 bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1293 break;
1294 case HServerHello:
1295 bs = seprint(bs, be, "ServerHello\n");
1296 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1297 bs = seprint(bs, be, "\trandom: ");
1298 for(i=0; i<RandomSize; i++)
1299 bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1300 bs = seprint(bs, be, "\n");
1301 bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1302 bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1303 bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1304 break;
1305 case HCertificate:
1306 bs = seprint(bs, be, "Certificate\n");
1307 for(i=0; i<m->u.certificate.ncert; i++)
1308 bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1309 break;
1310 case HCertificateRequest:
1311 bs = seprint(bs, be, "CertificateRequest\n");
1312 bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1313 bs = seprint(bs, be, "\tcertificateauthorities\n");
1314 for(i=0; i<m->u.certificateRequest.nca; i++)
1315 bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1316 break;
1317 case HServerHelloDone:
1318 bs = seprint(bs, be, "ServerHelloDone\n");
1319 break;
1320 case HClientKeyExchange:
1321 bs = seprint(bs, be, "HClientKeyExchange\n");
1322 bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1323 break;
1324 case HFinished:
1325 bs = seprint(bs, be, "HFinished\n");
1326 for(i=0; i<m->u.finished.n; i++)
1327 bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1328 bs = seprint(bs, be, "\n");
1329 break;
1331 USED(bs);
1332 return buf;
1335 static void
1336 tlsError(TlsConnection *c, int err, char *fmt, ...)
1338 char msg[512];
1339 va_list arg;
1341 va_start(arg, fmt);
1342 vseprint(msg, msg+sizeof(msg), fmt, arg);
1343 va_end(arg);
1344 if(c->trace)
1345 c->trace("tlsError: %s\n", msg);
1346 else if(c->erred)
1347 fprint(2, "double error: %r, %s", msg);
1348 else
1349 werrstr("tls: local %s", msg);
1350 c->erred = 1;
1351 fprint(c->ctl, "alert %d", err);
1354 // commit to specific version number
1355 static int
1356 setVersion(TlsConnection *c, int version)
1358 if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1359 return -1;
1360 if(version > c->version)
1361 version = c->version;
1362 if(version == SSL3Version) {
1363 c->version = version;
1364 c->finished.n = SSL3FinishedLen;
1365 }else if(version == TLSVersion){
1366 c->version = version;
1367 c->finished.n = TLSFinishedLen;
1368 }else
1369 return -1;
1370 c->verset = 1;
1371 return fprint(c->ctl, "version 0x%x", version);
1374 // confirm that received Finished message matches the expected value
1375 static int
1376 finishedMatch(TlsConnection *c, Finished *f)
1378 return memcmp(f->verify, c->finished.verify, f->n) == 0;
1381 // free memory associated with TlsConnection struct
1382 // (but don't close the TLS channel itself)
1383 static void
1384 tlsConnectionFree(TlsConnection *c)
1386 tlsSecClose(c->sec);
1387 freebytes(c->sid);
1388 freebytes(c->cert);
1389 memset(c, 0, sizeof(c));
1390 free(c);
1394 //================= cipher choices ========================
1396 static int weakCipher[CipherMax] =
1398 1, /* TLS_NULL_WITH_NULL_NULL */
1399 1, /* TLS_RSA_WITH_NULL_MD5 */
1400 1, /* TLS_RSA_WITH_NULL_SHA */
1401 1, /* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
1402 0, /* TLS_RSA_WITH_RC4_128_MD5 */
1403 0, /* TLS_RSA_WITH_RC4_128_SHA */
1404 1, /* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
1405 0, /* TLS_RSA_WITH_IDEA_CBC_SHA */
1406 1, /* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
1407 0, /* TLS_RSA_WITH_DES_CBC_SHA */
1408 0, /* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
1409 1, /* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
1410 0, /* TLS_DH_DSS_WITH_DES_CBC_SHA */
1411 0, /* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
1412 1, /* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
1413 0, /* TLS_DH_RSA_WITH_DES_CBC_SHA */
1414 0, /* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
1415 1, /* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
1416 0, /* TLS_DHE_DSS_WITH_DES_CBC_SHA */
1417 0, /* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
1418 1, /* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
1419 0, /* TLS_DHE_RSA_WITH_DES_CBC_SHA */
1420 0, /* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
1421 1, /* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
1422 1, /* TLS_DH_anon_WITH_RC4_128_MD5 */
1423 1, /* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
1424 1, /* TLS_DH_anon_WITH_DES_CBC_SHA */
1425 1, /* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
1428 static int
1429 setAlgs(TlsConnection *c, int a)
1431 int i;
1433 for(i = 0; i < nelem(cipherAlgs); i++){
1434 if(cipherAlgs[i].tlsid == a){
1435 c->enc = cipherAlgs[i].enc;
1436 c->digest = cipherAlgs[i].digest;
1437 c->nsecret = cipherAlgs[i].nsecret;
1438 if(c->nsecret > MaxKeyData)
1439 return 0;
1440 return 1;
1443 return 0;
1446 static int
1447 okCipher(Ints *cv)
1449 int weak, i, j, c;
1451 weak = 1;
1452 for(i = 0; i < cv->len; i++) {
1453 c = cv->data[i];
1454 if(c >= CipherMax)
1455 weak = 0;
1456 else
1457 weak &= weakCipher[c];
1458 for(j = 0; j < nelem(cipherAlgs); j++)
1459 if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
1460 return c;
1462 if(weak)
1463 return -2;
1464 return -1;
1467 static int
1468 okCompression(Bytes *cv)
1470 int i, j, c;
1472 for(i = 0; i < cv->len; i++) {
1473 c = cv->data[i];
1474 for(j = 0; j < nelem(compressors); j++) {
1475 if(compressors[j] == c)
1476 return c;
1479 return -1;
1482 static Lock ciphLock;
1483 static int nciphers;
1485 static int
1486 initCiphers(void)
1488 enum {MaxAlgF = 1024, MaxAlgs = 10};
1489 char s[MaxAlgF], *flds[MaxAlgs];
1490 int i, j, n, ok;
1492 lock(&ciphLock);
1493 if(nciphers){
1494 unlock(&ciphLock);
1495 return nciphers;
1497 j = open("#a/tls/encalgs", OREAD);
1498 if(j < 0){
1499 werrstr("can't open #a/tls/encalgs: %r");
1500 return 0;
1502 n = read(j, s, MaxAlgF-1);
1503 close(j);
1504 if(n <= 0){
1505 werrstr("nothing in #a/tls/encalgs: %r");
1506 return 0;
1508 s[n] = 0;
1509 n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1510 for(i = 0; i < nelem(cipherAlgs); i++){
1511 ok = 0;
1512 for(j = 0; j < n; j++){
1513 if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
1514 ok = 1;
1515 break;
1518 cipherAlgs[i].ok = ok;
1521 j = open("#a/tls/hashalgs", OREAD);
1522 if(j < 0){
1523 werrstr("can't open #a/tls/hashalgs: %r");
1524 return 0;
1526 n = read(j, s, MaxAlgF-1);
1527 close(j);
1528 if(n <= 0){
1529 werrstr("nothing in #a/tls/hashalgs: %r");
1530 return 0;
1532 s[n] = 0;
1533 n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
1534 for(i = 0; i < nelem(cipherAlgs); i++){
1535 ok = 0;
1536 for(j = 0; j < n; j++){
1537 if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
1538 ok = 1;
1539 break;
1542 cipherAlgs[i].ok &= ok;
1543 if(cipherAlgs[i].ok)
1544 nciphers++;
1546 unlock(&ciphLock);
1547 return nciphers;
1550 static Ints*
1551 makeciphers(void)
1553 Ints *is;
1554 int i, j;
1556 is = newints(nciphers);
1557 j = 0;
1558 for(i = 0; i < nelem(cipherAlgs); i++){
1559 if(cipherAlgs[i].ok)
1560 is->data[j++] = cipherAlgs[i].tlsid;
1562 return is;
1567 //================= security functions ========================
1569 // given X.509 certificate, set up connection to factotum
1570 // for using corresponding private key
1571 static AuthRpc*
1572 factotum_rsa_open(uchar *cert, int certlen)
1574 int afd;
1575 char *s;
1576 mpint *pub = nil;
1577 RSApub *rsapub;
1578 AuthRpc *rpc;
1580 // start talking to factotum
1581 if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
1582 return nil;
1583 if((rpc = auth_allocrpc(afd)) == nil){
1584 close(afd);
1585 return nil;
1587 s = "proto=rsa service=tls role=client";
1588 if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
1589 factotum_rsa_close(rpc);
1590 return nil;
1593 // roll factotum keyring around to match certificate
1594 rsapub = X509toRSApub(cert, certlen, nil, 0);
1595 while(1){
1596 if(auth_rpc(rpc, "read", nil, 0) != ARok){
1597 factotum_rsa_close(rpc);
1598 rpc = nil;
1599 goto done;
1601 pub = strtomp(rpc->arg, nil, 16, nil);
1602 assert(pub != nil);
1603 if(mpcmp(pub,rsapub->n) == 0)
1604 break;
1606 done:
1607 mpfree(pub);
1608 rsapubfree(rsapub);
1609 return rpc;
1612 static mpint*
1613 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
1615 char *p;
1616 int rv;
1618 if((p = mptoa(cipher, 16, nil, 0)) == nil)
1619 return nil;
1620 rv = auth_rpc(rpc, "write", p, strlen(p));
1621 free(p);
1622 if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
1623 return nil;
1624 mpfree(cipher);
1625 return strtomp(rpc->arg, nil, 16, nil);
1628 static void
1629 factotum_rsa_close(AuthRpc*rpc)
1631 if(!rpc)
1632 return;
1633 close(rpc->afd);
1634 auth_freerpc(rpc);
1637 static void
1638 tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1640 uchar ai[MD5dlen], tmp[MD5dlen];
1641 int i, n;
1642 MD5state *s;
1644 // generate a1
1645 s = hmac_md5(label, nlabel, key, nkey, nil, nil);
1646 s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1647 hmac_md5(seed1, nseed1, key, nkey, ai, s);
1649 while(nbuf > 0) {
1650 s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
1651 s = hmac_md5(label, nlabel, key, nkey, nil, s);
1652 s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
1653 hmac_md5(seed1, nseed1, key, nkey, tmp, s);
1654 n = MD5dlen;
1655 if(n > nbuf)
1656 n = nbuf;
1657 for(i = 0; i < n; i++)
1658 buf[i] ^= tmp[i];
1659 buf += n;
1660 nbuf -= n;
1661 hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
1662 memmove(ai, tmp, MD5dlen);
1666 static void
1667 tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1669 uchar ai[SHA1dlen], tmp[SHA1dlen];
1670 int i, n;
1671 SHAstate *s;
1673 // generate a1
1674 s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
1675 s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1676 hmac_sha1(seed1, nseed1, key, nkey, ai, s);
1678 while(nbuf > 0) {
1679 s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
1680 s = hmac_sha1(label, nlabel, key, nkey, nil, s);
1681 s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
1682 hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
1683 n = SHA1dlen;
1684 if(n > nbuf)
1685 n = nbuf;
1686 for(i = 0; i < n; i++)
1687 buf[i] ^= tmp[i];
1688 buf += n;
1689 nbuf -= n;
1690 hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
1691 memmove(ai, tmp, SHA1dlen);
1695 // fill buf with md5(args)^sha1(args)
1696 static void
1697 tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
1699 int i;
1700 int nlabel = strlen(label);
1701 int n = (nkey + 1) >> 1;
1703 for(i = 0; i < nbuf; i++)
1704 buf[i] = 0;
1705 tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1706 tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
1710 * for setting server session id's
1712 static Lock sidLock;
1713 static long maxSid = 1;
1715 /* the keys are verified to have the same public components
1716 * and to function correctly with pkcs 1 encryption and decryption. */
1717 static TlsSec*
1718 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
1720 TlsSec *sec = emalloc(sizeof(*sec));
1722 USED(csid); USED(ncsid); // ignore csid for now
1724 memmove(sec->crandom, crandom, RandomSize);
1725 sec->clientVers = cvers;
1727 put32(sec->srandom, time(0));
1728 genrandom(sec->srandom+4, RandomSize-4);
1729 memmove(srandom, sec->srandom, RandomSize);
1732 * make up a unique sid: use our pid, and and incrementing id
1733 * can signal no sid by setting nssid to 0.
1735 memset(ssid, 0, SidSize);
1736 put32(ssid, getpid());
1737 lock(&sidLock);
1738 put32(ssid+4, maxSid++);
1739 unlock(&sidLock);
1740 *nssid = SidSize;
1741 return sec;
1744 static int
1745 tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
1747 if(epm != nil){
1748 if(setVers(sec, vers) < 0)
1749 goto Err;
1750 serverMasterSecret(sec, epm, nepm);
1751 }else if(sec->vers != vers){
1752 werrstr("mismatched session versions");
1753 goto Err;
1755 setSecrets(sec, kd, nkd);
1756 return 0;
1757 Err:
1758 sec->ok = -1;
1759 return -1;
1762 static TlsSec*
1763 tlsSecInitc(int cvers, uchar *crandom)
1765 TlsSec *sec = emalloc(sizeof(*sec));
1766 sec->clientVers = cvers;
1767 put32(sec->crandom, time(0));
1768 genrandom(sec->crandom+4, RandomSize-4);
1769 memmove(crandom, sec->crandom, RandomSize);
1770 return sec;
1773 static int
1774 tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
1776 RSApub *pub;
1778 pub = nil;
1780 USED(sid);
1781 USED(nsid);
1783 memmove(sec->srandom, srandom, RandomSize);
1785 if(setVers(sec, vers) < 0)
1786 goto Err;
1788 pub = X509toRSApub(cert, ncert, nil, 0);
1789 if(pub == nil){
1790 werrstr("invalid x509/rsa certificate");
1791 goto Err;
1793 if(clientMasterSecret(sec, pub, epm, nepm) < 0)
1794 goto Err;
1795 rsapubfree(pub);
1796 setSecrets(sec, kd, nkd);
1797 return 0;
1799 Err:
1800 if(pub != nil)
1801 rsapubfree(pub);
1802 sec->ok = -1;
1803 return -1;
1806 static int
1807 tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
1809 if(sec->nfin != nfin){
1810 sec->ok = -1;
1811 werrstr("invalid finished exchange");
1812 return -1;
1814 md5.malloced = 0;
1815 sha1.malloced = 0;
1816 (*sec->setFinished)(sec, md5, sha1, fin, isclient);
1817 return 1;
1820 static void
1821 tlsSecOk(TlsSec *sec)
1823 if(sec->ok == 0)
1824 sec->ok = 1;
1827 static void
1828 tlsSecKill(TlsSec *sec)
1830 if(!sec)
1831 return;
1832 factotum_rsa_close(sec->rpc);
1833 sec->ok = -1;
1836 static void
1837 tlsSecClose(TlsSec *sec)
1839 if(!sec)
1840 return;
1841 factotum_rsa_close(sec->rpc);
1842 free(sec->server);
1843 free(sec);
1846 static int
1847 setVers(TlsSec *sec, int v)
1849 if(v == SSL3Version){
1850 sec->setFinished = sslSetFinished;
1851 sec->nfin = SSL3FinishedLen;
1852 sec->prf = sslPRF;
1853 }else if(v == TLSVersion){
1854 sec->setFinished = tlsSetFinished;
1855 sec->nfin = TLSFinishedLen;
1856 sec->prf = tlsPRF;
1857 }else{
1858 werrstr("invalid version");
1859 return -1;
1861 sec->vers = v;
1862 return 0;
1866 * generate secret keys from the master secret.
1868 * different crypto selections will require different amounts
1869 * of key expansion and use of key expansion data,
1870 * but it's all generated using the same function.
1872 static void
1873 setSecrets(TlsSec *sec, uchar *kd, int nkd)
1875 (*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
1876 sec->srandom, RandomSize, sec->crandom, RandomSize);
1880 * set the master secret from the pre-master secret.
1882 static void
1883 setMasterSecret(TlsSec *sec, Bytes *pm)
1885 (*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
1886 sec->crandom, RandomSize, sec->srandom, RandomSize);
1889 static void
1890 serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
1892 Bytes *pm;
1894 pm = pkcs1_decrypt(sec, epm, nepm);
1896 // if the client messed up, just continue as if everything is ok,
1897 // to prevent attacks to check for correctly formatted messages.
1898 // Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
1899 if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
1900 fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
1901 sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
1902 sec->ok = -1;
1903 if(pm != nil)
1904 freebytes(pm);
1905 pm = newbytes(MasterSecretSize);
1906 genrandom(pm->data, MasterSecretSize);
1908 setMasterSecret(sec, pm);
1909 memset(pm->data, 0, pm->len);
1910 freebytes(pm);
1913 static int
1914 clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
1916 Bytes *pm, *key;
1918 pm = newbytes(MasterSecretSize);
1919 put16(pm->data, sec->clientVers);
1920 genrandom(pm->data+2, MasterSecretSize - 2);
1922 setMasterSecret(sec, pm);
1924 key = pkcs1_encrypt(pm, pub, 2);
1925 memset(pm->data, 0, pm->len);
1926 freebytes(pm);
1927 if(key == nil){
1928 werrstr("tls pkcs1_encrypt failed");
1929 return -1;
1932 *nepm = key->len;
1933 *epm = malloc(*nepm);
1934 if(*epm == nil){
1935 freebytes(key);
1936 werrstr("out of memory");
1937 return -1;
1939 memmove(*epm, key->data, *nepm);
1941 freebytes(key);
1943 return 1;
// sslSetFinished computes the SSL 3.0 Finished verify data for the sender
// selected by isClient ("CLNT" or "SRVR") and writes it into finished:
// MD5dlen bytes of the MD5 construction, then SHA1dlen bytes of the SHA-1
// construction.  hsmd5/hssha1 are by-value copies of the running handshake
// hash states.  Each half is
//	hash(master + pad2 + hash(handshake + sender + master + pad1))
// with pad1 = 0x36 bytes and pad2 = 0x5C bytes (48 of each for MD5, 40 for
// SHA-1).
static void
sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
{
	DigestState *s;
	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
	char *label;

	if(isClient)
		label = "CLNT";
	else
		label = "SRVR";

	// MD5 half: inner hash continues the handshake hash state
	md5((uchar*)label, 4, nil, &hsmd5);
	md5(sec->sec, MasterSecretSize, nil, &hsmd5);
	memset(pad, 0x36, 48);
	md5(pad, 48, nil, &hsmd5);
	md5(nil, 0, h0, &hsmd5);
	// outer hash starts fresh: master + pad2 + inner
	memset(pad, 0x5C, 48);
	s = md5(sec->sec, MasterSecretSize, nil, nil);
	s = md5(pad, 48, nil, s);
	md5(h0, MD5dlen, finished, s);

	// SHA-1 half, written after the MD5 bytes
	sha1((uchar*)label, 4, nil, &hssha1);
	sha1(sec->sec, MasterSecretSize, nil, &hssha1);
	memset(pad, 0x36, 40);
	sha1(pad, 40, nil, &hssha1);
	sha1(nil, 0, h1, &hssha1);
	memset(pad, 0x5C, 40);
	s = sha1(sec->sec, MasterSecretSize, nil, nil);
	s = sha1(pad, 40, nil, s);
	sha1(h1, SHA1dlen, finished + MD5dlen, s);
}
1979 // fill "finished" arg with md5(args)^sha1(args)
1980 static void
1981 tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
1983 uchar h0[MD5dlen], h1[SHA1dlen];
1984 char *label;
1986 // get current hash value, but allow further messages to be hashed in
1987 md5(nil, 0, h0, &hsmd5);
1988 sha1(nil, 0, h1, &hssha1);
1990 if(isClient)
1991 label = "client finished";
1992 else
1993 label = "server finished";
1994 tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
1997 static void
1998 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2000 DigestState *s;
2001 uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2002 int i, n, len;
2004 USED(label);
2005 len = 1;
2006 while(nbuf > 0){
2007 if(len > 26)
2008 return;
2009 for(i = 0; i < len; i++)
2010 tmp[i] = 'A' - 1 + len;
2011 s = sha1(tmp, len, nil, nil);
2012 s = sha1(key, nkey, nil, s);
2013 s = sha1(seed0, nseed0, nil, s);
2014 sha1(seed1, nseed1, sha1dig, s);
2015 s = md5(key, nkey, nil, nil);
2016 md5(sha1dig, SHA1dlen, md5dig, s);
2017 n = MD5dlen;
2018 if(n > nbuf)
2019 n = nbuf;
2020 memmove(buf, md5dig, n);
2021 buf += n;
2022 nbuf -= n;
2023 len++;
2027 static mpint*
2028 bytestomp(Bytes* bytes)
2030 mpint* ans;
2032 ans = betomp(bytes->data, bytes->len, nil);
2033 return ans;
2037 * Convert mpint* to Bytes, putting high order byte first.
2039 static Bytes*
2040 mptobytes(mpint* big)
2042 int n, m;
2043 uchar *a;
2044 Bytes* ans;
2046 n = (mpsignif(big)+7)/8;
2047 m = mptobe(big, nil, n, &a);
2048 ans = makebytes(a, m);
2049 return ans;
2052 // Do RSA computation on block according to key, and pad
2053 // result on left with zeros to make it modlen long.
2054 static Bytes*
2055 rsacomp(Bytes* block, RSApub* key, int modlen)
2057 mpint *x, *y;
2058 Bytes *a, *ybytes;
2059 int ylen;
2061 x = bytestomp(block);
2062 y = rsaencrypt(key, x, nil);
2063 mpfree(x);
2064 ybytes = mptobytes(y);
2065 ylen = ybytes->len;
2067 if(ylen < modlen) {
2068 a = newbytes(modlen);
2069 memset(a->data, 0, modlen-ylen);
2070 memmove(a->data+modlen-ylen, ybytes->data, ylen);
2071 freebytes(ybytes);
2072 ybytes = a;
2074 else if(ylen > modlen) {
2075 // assume it has leading zeros (mod should make it so)
2076 a = newbytes(modlen);
2077 memmove(a->data, ybytes->data, modlen);
2078 freebytes(ybytes);
2079 ybytes = a;
2081 mpfree(y);
2082 return ybytes;
2085 // encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2086 static Bytes*
2087 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2089 Bytes *pad, *eb, *ans;
2090 int i, dlen, padlen, modlen;
2092 modlen = (mpsignif(key->n)+7)/8;
2093 dlen = data->len;
2094 if(modlen < 12 || dlen > modlen - 11)
2095 return nil;
2096 padlen = modlen - 3 - dlen;
2097 pad = newbytes(padlen);
2098 genrandom(pad->data, padlen);
2099 for(i = 0; i < padlen; i++) {
2100 if(blocktype == 0)
2101 pad->data[i] = 0;
2102 else if(blocktype == 1)
2103 pad->data[i] = 255;
2104 else if(pad->data[i] == 0)
2105 pad->data[i] = 1;
2107 eb = newbytes(modlen);
2108 eb->data[0] = 0;
2109 eb->data[1] = blocktype;
2110 memmove(eb->data+2, pad->data, padlen);
2111 eb->data[padlen+2] = 0;
2112 memmove(eb->data+padlen+3, data->data, dlen);
2113 ans = rsacomp(eb, key, modlen);
2114 freebytes(eb);
2115 freebytes(pad);
2116 return ans;
2119 // decrypt data according to PKCS#1, with given key.
2120 // expect a block type of 2.
2121 static Bytes*
2122 pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
2124 Bytes *eb, *ans = nil;
2125 int i, modlen;
2126 mpint *x, *y;
2128 modlen = (mpsignif(sec->rsapub->n)+7)/8;
2129 if(nepm != modlen)
2130 return nil;
2131 x = betomp(epm, nepm, nil);
2132 y = factotum_rsa_decrypt(sec->rpc, x);
2133 if(y == nil)
2134 return nil;
2135 eb = mptobytes(y);
2136 if(eb->len < modlen){ // pad on left with zeros
2137 ans = newbytes(modlen);
2138 memset(ans->data, 0, modlen-eb->len);
2139 memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2140 freebytes(eb);
2141 eb = ans;
2143 if(eb->data[0] == 0 && eb->data[1] == 2) {
2144 for(i = 2; i < modlen; i++)
2145 if(eb->data[i] == 0)
2146 break;
2147 if(i < modlen - 1)
2148 ans = makebytes(eb->data+i+1, modlen-(i+1));
2150 freebytes(eb);
2151 return ans;
2155 //================= general utility functions ========================
2157 static void *
2158 emalloc(int n)
2160 void *p;
2161 if(n==0)
2162 n=1;
2163 p = malloc(n);
2164 if(p == nil){
2165 exits("out of memory");
2167 memset(p, 0, n);
2168 return p;
2171 static void *
2172 erealloc(void *ReallocP, int ReallocN)
2174 if(ReallocN == 0)
2175 ReallocN = 1;
2176 if(!ReallocP)
2177 ReallocP = emalloc(ReallocN);
2178 else if(!(ReallocP = realloc(ReallocP, ReallocN))){
2179 exits("out of memory");
2181 return(ReallocP);
2184 static void
2185 put32(uchar *p, u32int x)
2187 p[0] = x>>24;
2188 p[1] = x>>16;
2189 p[2] = x>>8;
2190 p[3] = x;
2193 static void
2194 put24(uchar *p, int x)
2196 p[0] = x>>16;
2197 p[1] = x>>8;
2198 p[2] = x;
2201 static void
2202 put16(uchar *p, int x)
2204 p[0] = x>>8;
2205 p[1] = x;
2208 static u32int
2209 get32(uchar *p)
2211 return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2214 static int
2215 get24(uchar *p)
2217 return (p[0]<<16)|(p[1]<<8)|p[2];
2220 static int
2221 get16(uchar *p)
2223 return (p[0]<<8)|p[1];
/* ANSI offsetof(): the byte offset of member x within struct type s, taken
 * from the member's address relative to a null base pointer, cast to int. */
#define OFFSET(x, s) ((int)(&(((s*)0)->x)))
2230 * malloc and return a new Bytes structure capable of
2231 * holding len bytes. (len >= 0)
2232 * Used to use crypt_malloc, which aborts if malloc fails.
2234 static Bytes*
2235 newbytes(int len)
2237 Bytes* ans;
2239 ans = (Bytes*)malloc(OFFSET(data[0], Bytes) + len);
2240 ans->len = len;
2241 return ans;
2245 * newbytes(len), with data initialized from buf
2247 static Bytes*
2248 makebytes(uchar* buf, int len)
2250 Bytes* ans;
2252 ans = newbytes(len);
2253 memmove(ans->data, buf, len);
2254 return ans;
2257 static void
2258 freebytes(Bytes* b)
2260 if(b != nil)
2261 free(b);
2264 /* len is number of ints */
2265 static Ints*
2266 newints(int len)
2268 Ints* ans;
2270 ans = (Ints*)malloc(OFFSET(data[0], Ints) + len*sizeof(int));
2271 ans->len = len;
2272 return ans;
2275 static Ints*
2276 makeints(int* buf, int len)
2278 Ints* ans;
2280 ans = newints(len);
2281 if(len > 0)
2282 memmove(ans->data, buf, len*sizeof(int));
2283 return ans;
2286 static void
2287 freeints(Ints* b)
2289 if(b != nil)
2290 free(b);