Blob


1 #include "os.h"
2 #include <mp.h>
3 #include "dat.h"
/* */
/* from knuth's 1969 seminumerical algorithms, pp 233-235 and pp 258-260 */
/* */
/* the inner loop of mpvecmul is performed by mpvecdigmuladd, presumably the */
/* assembly language routine this note refers to. */
/* */
/* the karatsuba trade off is set empirically by measuring the algs on */
/* a 400 MHz Pentium II. */
/* */
15 /* karatsuba like (see knuth pg 258) */
16 /* prereq: p is already zeroed */
17 static void
18 mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
19 {
20 mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod;
21 int u0len, u1len, v0len, v1len, reslen;
22 int sign, n;
24 /* divide each piece in half */
25 n = alen/2;
26 if(alen&1)
27 n++;
28 u0len = n;
29 u1len = alen-n;
30 if(blen > n){
31 v0len = n;
32 v1len = blen-n;
33 } else {
34 v0len = blen;
35 v1len = 0;
36 }
37 u0 = a;
38 u1 = a + u0len;
39 v0 = b;
40 v1 = b + v0len;
42 /* room for the partial products */
43 t = mallocz(Dbytes*5*(2*n+1), 1);
44 if(t == nil)
45 sysfatal("mpkaratsuba: %r");
46 u0v0 = t;
47 u1v1 = t + (2*n+1);
48 diffprod = t + 2*(2*n+1);
49 res = t + 3*(2*n+1);
50 reslen = 4*n+1;
52 /* t[0] = (u1-u0) */
53 sign = 1;
54 if(mpveccmp(u1, u1len, u0, u0len) < 0){
55 sign = -1;
56 mpvecsub(u0, u0len, u1, u1len, u0v0);
57 } else
58 mpvecsub(u1, u1len, u0, u1len, u0v0);
60 /* t[1] = (v0-v1) */
61 if(mpveccmp(v0, v0len, v1, v1len) < 0){
62 sign *= -1;
63 mpvecsub(v1, v1len, v0, v1len, u1v1);
64 } else
65 mpvecsub(v0, v0len, v1, v1len, u1v1);
67 /* t[4:5] = (u1-u0)*(v0-v1) */
68 mpvecmul(u0v0, u0len, u1v1, v0len, diffprod);
70 /* t[0:1] = u1*v1 */
71 memset(t, 0, 2*(2*n+1)*Dbytes);
72 if(v1len > 0)
73 mpvecmul(u1, u1len, v1, v1len, u1v1);
75 /* t[2:3] = u0v0 */
76 mpvecmul(u0, u0len, v0, v0len, u0v0);
78 /* res = u0*v0<<n + u0*v0 */
79 mpvecadd(res, reslen, u0v0, u0len+v0len, res);
80 mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n);
82 /* res += u1*v1<<n + u1*v1<<2*n */
83 if(v1len > 0){
84 mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n);
85 mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n);
86 }
88 /* res += (u1-u0)*(v0-v1)<<n */
89 if(sign < 0)
90 mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n);
91 else
92 mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n);
93 memmove(p, res, (alen+blen)*Dbytes);
95 free(t);
96 }
98 #define KARATSUBAMIN 32
100 void
101 mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
103 int i;
104 mpdigit d;
105 mpdigit *t;
107 /* both mpvecdigmuladd and karatsuba are fastest when a is the longer vector */
108 if(alen < blen){
109 i = alen;
110 alen = blen;
111 blen = i;
112 t = a;
113 a = b;
114 b = t;
116 if(blen == 0){
117 memset(p, 0, Dbytes*(alen+blen));
118 return;
121 if(alen >= KARATSUBAMIN && blen > 1){
122 /* O(n^1.585) */
123 mpkaratsuba(a, alen, b, blen, p);
124 } else {
125 /* O(n^2) */
126 for(i = 0; i < blen; i++){
127 d = b[i];
128 if(d != 0)
129 mpvecdigmuladd(a, alen, d, &p[i]);
134 void
135 mpmul(mpint *b1, mpint *b2, mpint *prod)
137 mpint *oprod;
139 oprod = nil;
140 if(prod == b1 || prod == b2){
141 oprod = prod;
142 prod = mpnew(0);
145 prod->top = 0;
146 mpbits(prod, (b1->top+b2->top+1)*Dbits);
147 mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
148 prod->top = b1->top+b2->top+1;
149 prod->sign = b1->sign*b2->sign;
150 mpnorm(prod);
152 if(oprod != nil){
153 mpassign(prod, oprod);
154 mpfree(prod);