Blob


1 #include <u.h>
2 #include <libc.h>
3 #include <mp.h>
4 #include <libsec.h>
5 #include "SConn.h"
7 extern int verbose;
9 typedef struct ConnState {
10 uchar secret[SHA1dlen];
11 ulong seqno;
12 RC4state rc4;
13 } ConnState;
15 typedef struct SS{
16 int fd; // file descriptor for read/write of encrypted data
17 int alg; // if nonzero, "alg sha rc4_128"
18 ConnState in, out;
19 } SS;
21 static int
22 SC_secret(SConn *conn, uchar *sigma, int direction)
23 {
24 SS *ss = (SS*)(conn->chan);
25 int nsigma = conn->secretlen;
27 if(direction != 0){
28 hmac_sha1(sigma, nsigma, (uchar*)"one", 3, ss->out.secret, nil);
29 hmac_sha1(sigma, nsigma, (uchar*)"two", 3, ss->in.secret, nil);
30 }else{
31 hmac_sha1(sigma, nsigma, (uchar*)"two", 3, ss->out.secret, nil);
32 hmac_sha1(sigma, nsigma, (uchar*)"one", 3, ss->in.secret, nil);
33 }
34 setupRC4state(&ss->in.rc4, ss->in.secret, 16); // restrict to 128 bits
35 setupRC4state(&ss->out.rc4, ss->out.secret, 16);
36 ss->alg = 1;
37 return 0;
38 }
40 static void
41 hash(uchar secret[SHA1dlen], uchar *data, int len, int seqno, uchar d[SHA1dlen])
42 {
43 DigestState sha;
44 uchar seq[4];
46 seq[0] = seqno>>24;
47 seq[1] = seqno>>16;
48 seq[2] = seqno>>8;
49 seq[3] = seqno;
50 memset(&sha, 0, sizeof sha);
51 sha1(secret, SHA1dlen, nil, &sha);
52 sha1(data, len, nil, &sha);
53 sha1(seq, 4, d, &sha);
54 }
56 static int
57 verify(uchar secret[SHA1dlen], uchar *data, int len, int seqno, uchar d[SHA1dlen])
58 {
59 DigestState sha;
60 uchar seq[4];
61 uchar digest[SHA1dlen];
63 seq[0] = seqno>>24;
64 seq[1] = seqno>>16;
65 seq[2] = seqno>>8;
66 seq[3] = seqno;
67 memset(&sha, 0, sizeof sha);
68 sha1(secret, SHA1dlen, nil, &sha);
69 sha1(data, len, nil, &sha);
70 sha1(seq, 4, digest, &sha);
71 return memcmp(d, digest, SHA1dlen);
72 }
74 static int
75 SC_read(SConn *conn, uchar *buf, int n)
76 {
77 SS *ss = (SS*)(conn->chan);
78 uchar count[2], digest[SHA1dlen];
79 int len, nr;
81 if(read(ss->fd, count, 2) != 2 || (count[0]&0x80) == 0){
82 snprint((char*)buf,n,"!SC_read invalid count");
83 return -1;
84 }
85 len = (count[0]&0x7f)<<8 | count[1]; // SSL-style count; no pad
86 if(ss->alg){
87 len -= SHA1dlen;
88 if(len <= 0 || readn(ss->fd, digest, SHA1dlen) != SHA1dlen){
89 snprint((char*)buf,n,"!SC_read missing sha1");
90 return -1;
91 }
92 if(len > n || readn(ss->fd, buf, len) != len){
93 snprint((char*)buf,n,"!SC_read missing data");
94 return -1;
95 }
96 rc4(&ss->in.rc4, digest, SHA1dlen);
97 rc4(&ss->in.rc4, buf, len);
98 if(verify(ss->in.secret, buf, len, ss->in.seqno, digest) != 0){
99 snprint((char*)buf,n,"!SC_read integrity check failed");
100 return -1;
102 }else{
103 if(len <= 0 || len > n){
104 snprint((char*)buf,n,"!SC_read implausible record length");
105 return -1;
107 if( (nr = readn(ss->fd, buf, len)) != len){
108 snprint((char*)buf,n,"!SC_read expected %d bytes, but got %d", len, nr);
109 return -1;
112 ss->in.seqno++;
113 return len;
116 static int
117 SC_write(SConn *conn, uchar *buf, int n)
119 SS *ss = (SS*)(conn->chan);
120 uchar count[2], digest[SHA1dlen], enc[Maxmsg+1];
121 int len;
123 if(n <= 0 || n > Maxmsg+1){
124 werrstr("!SC_write invalid n %d", n);
125 return -1;
127 len = n;
128 if(ss->alg)
129 len += SHA1dlen;
130 count[0] = 0x80 | len>>8;
131 count[1] = len;
132 if(write(ss->fd, count, 2) != 2){
133 werrstr("!SC_write invalid count");
134 return -1;
136 if(ss->alg){
137 hash(ss->out.secret, buf, n, ss->out.seqno, digest);
138 rc4(&ss->out.rc4, digest, SHA1dlen);
139 memcpy(enc, buf, n);
140 rc4(&ss->out.rc4, enc, n);
141 if(write(ss->fd, digest, SHA1dlen) != SHA1dlen ||
142 write(ss->fd, enc, n) != n){
143 werrstr("!SC_write error on send");
144 return -1;
146 }else{
147 if(write(ss->fd, buf, n) != n){
148 werrstr("!SC_write error on send");
149 return -1;
152 ss->out.seqno++;
153 return n;
156 static void
157 SC_free(SConn *conn)
159 SS *ss = (SS*)(conn->chan);
161 close(ss->fd);
162 free(ss);
163 free(conn);
166 SConn*
167 newSConn(int fd)
169 SS *ss;
170 SConn *conn;
172 if(fd < 0)
173 return nil;
174 ss = (SS*)emalloc(sizeof(*ss));
175 conn = (SConn*)emalloc(sizeof(*conn));
176 ss->fd = fd;
177 ss->alg = 0;
178 conn->chan = (void*)ss;
179 conn->secretlen = SHA1dlen;
180 conn->free = SC_free;
181 conn->secret = SC_secret;
182 conn->read = SC_read;
183 conn->write = SC_write;
184 return conn;
187 void
188 writerr(SConn *conn, char *s)
190 char buf[Maxmsg];
192 snprint(buf, Maxmsg, "!%s", s);
193 conn->write(conn, (uchar*)buf, strlen(buf));
196 int
197 readstr(SConn *conn, char *s)
199 int n;
201 n = conn->read(conn, (uchar*)s, Maxmsg);
202 if(n >= 0){
203 s[n] = 0;
204 if(s[0] == '!'){
205 memmove(s, s+1, n);
206 n = -1;
208 }else{
209 strcpy(s, "read error");
211 return n;