diff --git a/bn.c b/bn.c index 6d258d0..0debe07 100644 --- a/bn.c +++ b/bn.c @@ -25,17 +25,116 @@ static const char *s_rmap = #ifdef DEBUG -static char *_funcs[1000]; -int _ifuncs; +/* timing data */ +#ifdef TIMER_X86 +extern ulong64 gettsc(void); +#else +ulong64 gettsc(void) { return clock(); } +#endif -#define REGFUNC(name) { if (_ifuncs == 999) { printf("TROUBLE\n"); exit(0); } _funcs[_ifuncs++] = name; } -#define DECFUNC() --_ifuncs; +/* structure to hold timing data */ +struct { + char *func; + ulong64 start, end, tot; +} timings[1000000]; + +/* structure to hold consolidated timing data */ +struct _functime { + char *func; + ulong64 tot; +} functime[1000]; + +static char *_funcs[1000]; +int _ifuncs, _itims; + +#define REGFUNC(name) int __IX = _itims++; _funcs[_ifuncs++] = name; timings[__IX].func = name; timings[__IX].start = gettsc(); +#define DECFUNC() timings[__IX].end = gettsc(); --_ifuncs; #define VERIFY(val) _verify(val, #val, __LINE__); +/* sort the consolidated timings */ +int qsort_helper(const void *A, const void *B) +{ + struct _functime *a, *b; + + a = (struct _functime *)A; + b = (struct _functime *)B; + + if (a->tot > b->tot) return -1; + if (a->tot < b->tot) return 1; + return 0; +} + +/* reset debugging information */ +void reset_timings(void) +{ + _ifuncs = _itims = 0; +} + +/* dump the timing data */ +void dump_timings(void) +{ + int x, y; + ulong64 total; + + /* first, for every sample, find the total time */ + printf("Phase I ... Finding totals (%d samples)...\n", _itims); + for (x = 0; x < _itims; x++) { + timings[x].tot = timings[x].end - timings[x].start; + } + + /* now subtract the time for each function where nested functions occurred */ + printf("Phase II ... 
Finding dependencies...\n"); + for (x = 0; x < _itims-1; x++) { + for (y = x+1; y < _itims && timings[y].start <= timings[x].end; y++) { + timings[x].tot -= timings[y].tot; + if (timings[x].tot > ((ulong64)1 << (ulong64)40)) { + timings[x].tot = 0; + } + } + } + + /* now consolidate all the entries */ + printf("Phase III... Consolidation...\n"); + + memset(&functime, 0, sizeof(functime)); + total = 0; + for (x = 0; x < _itims; x++) { + total += timings[x].tot; + + /* try to find this entry */ + for (y = 0; functime[y].func != NULL; y++) { + if (strcmp(timings[x].func, functime[y].func) == 0) { + break; + } + } + + if (functime[y].func == NULL) { + /* new entry */ + functime[y].func = timings[x].func; + functime[y].tot = timings[x].tot; + } else { + functime[y].tot += timings[x].tot; + } + } + + for (x = 0; functime[x].func != NULL; x++); + + /* sort and dump */ + qsort(&functime, x, sizeof(functime[0]), &qsort_helper); + + for (x = 0; functime[x].func != NULL; x++) { + if (functime[x].tot > 0 && strcmp(functime[x].func, "_verify") != 0) { + printf("%30s: %20llu (%3llu.%03llu %%)\n", functime[x].func, functime[x].tot, (functime[x].tot * (ulong64)100) / total, ((functime[x].tot * (ulong64)100000) / total) % (ulong64)1000); + } + } +} + static void _verify(mp_int *a, char *name, int line) { int n, y; static const char *err[] = { "Null DP", "alloc < used", "digits above used" }; + + REGFUNC("_verify"); /* dp null ? 
*/ y = 0; @@ -52,6 +151,7 @@ static void _verify(mp_int *a, char *name, int line) } /* ok */ + DECFUNC(); return; error: printf("Error (%s) with variable {%s} on line %d\n", err[y], name, line); @@ -64,7 +164,7 @@ error: exit(0); } -#else +#else /* don't use DEBUG stuff so these macros are blank */ #define REGFUNC(name) #define DECFUNC() @@ -76,13 +176,18 @@ error: int mp_init(mp_int *a) { REGFUNC("mp_init"); - a->dp = calloc(sizeof(mp_digit), 16); + + /* allocate ram required and clear it */ + a->dp = calloc(sizeof(mp_digit), MP_PREC); if (a->dp == NULL) { DECFUNC(); return MP_MEM; } + + /* set the used to zero, allocated digit to the default precision + * and sign to positive */ a->used = 0; - a->alloc = 16; + a->alloc = MP_PREC; a->sign = MP_ZPOS; VERIFY(a); @@ -96,8 +201,14 @@ void mp_clear(mp_int *a) REGFUNC("mp_clear"); if (a->dp != NULL) { VERIFY(a); - memset(a->dp, 0, sizeof(mp_digit) * a->alloc); + + /* first zero the digits */ + memset(a->dp, 0, sizeof(mp_digit) * a->used); + + /* free ram */ free(a->dp); + + /* reset members to make debugging easier */ a->dp = NULL; a->alloc = a->used = 0; } @@ -118,27 +229,26 @@ void mp_exch(mp_int *a, mp_int *b) /* grow as required */ static int mp_grow(mp_int *a, int size) { - int i; - mp_digit *tmp; + int i, n; REGFUNC("mp_grow"); VERIFY(a); /* if the alloc size is smaller alloc more ram */ if (a->alloc < size) { - size += 32 - (size & 15); /* ensure there are always at least 16 digits extra on top */ + size += (MP_PREC*2) - (size & (MP_PREC-1)); /* ensure there are always at least 16 digits extra on top */ - tmp = calloc(sizeof(mp_digit), size); - if (tmp == NULL) { + a->dp = realloc(a->dp, sizeof(mp_digit)*size); + if (a->dp == NULL) { DECFUNC(); return MP_MEM; } - for (i = 0; i < a->used; i++) { - tmp[i] = a->dp[i]; - } - free(a->dp); - a->dp = tmp; + + n = a->alloc; a->alloc = size; + for (i = n; i < a->alloc; i++) { + a->dp[i] = 0; + } } DECFUNC(); return MP_OKAY; @@ -233,8 +343,7 @@ int mp_init_size(mp_int *a, 
int size) REGFUNC("mp_init_size"); /* pad up so there are at least 16 zero digits */ - size += 32 - (size & 15); - + size += (MP_PREC*2) - (size & (MP_PREC-1)); /* ensure there are always at least MP_PREC digits extra on top */ a->dp = calloc(sizeof(mp_digit), size); if (a->dp == NULL) { DECFUNC(); @@ -270,7 +379,6 @@ int mp_copy(mp_int *a, mp_int *b) } /* zero b and copy the parameters over */ - mp_zero(b); b->used = a->used; b->sign = a->sign; @@ -278,6 +386,11 @@ int mp_copy(mp_int *a, mp_int *b) for (n = 0; n < a->used; n++) { b->dp[n] = a->dp[n]; } + + /* clear high digits */ + for (n = b->used; n < b->alloc; n++) { + b->dp[n] = 0; + } DECFUNC(); return MP_OKAY; } @@ -513,7 +626,7 @@ int mp_mod_2d(mp_int *a, int b, mp_int *c) c->dp[x] = 0; } /* clear the digit that is not completely outside/inside the modulus */ - c->dp[b/DIGIT_BIT] &= (mp_digit)((((mp_digit)1)<<(b % DIGIT_BIT)) - ((mp_digit)1)); + c->dp[b/DIGIT_BIT] &= (mp_digit)((((mp_digit)1)<<(((mp_digit)b) % DIGIT_BIT)) - ((mp_digit)1)); mp_clamp(c); DECFUNC(); return MP_OKAY; @@ -697,7 +810,7 @@ int mp_mul_2(mp_int *a, mp_int *b) return MP_OKAY; } -/* low level addition */ +/* low level addition, based on HAC pp.594, Algorithm 14.7 */ static int s_mp_add(mp_int *a, mp_int *b, mp_int *c) { mp_int *x; @@ -709,7 +822,9 @@ static int s_mp_add(mp_int *a, mp_int *b, mp_int *c) VERIFY(b); VERIFY(c); - /* find sizes */ + /* find sizes, we let |a| <= |b| which means we have to sort + * them. 
"x" will point to the input with the most digits + */ if (a->used > b->used) { min = b->used; max = a->used; @@ -735,9 +850,11 @@ static int s_mp_add(mp_int *a, mp_int *b, mp_int *c) c->used = max + 1; /* add digits from lower part */ + + /* set the carry to zero */ u = 0; for (i = 0; i < min; i++) { - /* T[i] = A[i] + B[i] + U */ + /* Compute the sum at one digit, T[i] = A[i] + B[i] + U */ c->dp[i] = a->dp[i] + b->dp[i] + u; /* U = carry bit of T[i] */ @@ -774,7 +891,7 @@ static int s_mp_add(mp_int *a, mp_int *b, mp_int *c) return MP_OKAY; } -/* low level subtraction (assumes a > b) */ +/* low level subtraction (assumes a > b), HAC pp.595 Algorithm 14.9 */ static int s_mp_sub(mp_int *a, mp_int *b, mp_int *c) { int olduse, res, min, max, i; @@ -800,6 +917,8 @@ static int s_mp_sub(mp_int *a, mp_int *b, mp_int *c) c->used = max; /* sub digits from lower part */ + + /* set carry to zero */ u = 0; for (i = 0; i < min; i++) { /* T[i] = A[i] - B[i] - U */ @@ -849,9 +968,8 @@ static int s_mp_sub(mp_int *a, mp_int *b, mp_int *c) */ static int fast_s_mp_mul_digs(mp_int *a, mp_int *b, mp_int *c, int digs) { - int olduse, res, pa, pb, ix, iy; - mp_word W[512], *_W; - mp_digit tmpx, *tmpy; + int olduse, res, pa, ix; + mp_word W[512]; REGFUNC("fast_s_mp_mul_digs"); VERIFY(a); @@ -866,7 +984,7 @@ static int fast_s_mp_mul_digs(mp_int *a, mp_int *b, mp_int *c, int digs) } /* clear temp buf (the columns) */ - memset(W, 0, digs*sizeof(mp_word)); + memset(W, 0, sizeof(mp_word) * digs); /* calculate the columns */ pa = a->used; @@ -876,21 +994,41 @@ static int fast_s_mp_mul_digs(mp_int *a, mp_int *b, mp_int *c, int digs) * of output are produced. 
So at most we want to make upto "digs" digits * of output */ - pb = MIN(b->used, digs - ix); - /* setup some pointer aliases to simplify the inner loop */ - tmpx = a->dp[ix]; - tmpy = b->dp; - _W = &(W[ix]); /* this adds products to distinct columns (at ix+iy) of W * note that each step through the loop is not dependent on * the previous which means the compiler can easily unroll * the loop without scheduling problems */ - for (iy = 0; iy < pb; iy++) { - *_W++ += ((mp_word)tmpx) * ((mp_word)*tmpy++); + { + register mp_digit tmpx, *tmpy; + register mp_word *_W; + register int iy, pb; + + /* alias for the word on the left, e.g. A[ix] in A[ix] * B[iy] */ + tmpx = a->dp[ix]; + + /* alias for the right side */ + tmpy = b->dp; + + /* alias for the columns, each step through the loop adds a new + term to each column + */ + _W = W + ix; + + + /* the number of digits is limited by their placement. E.g. + we avoid multiplying digits that will end up above the # of + digits of precision requested + */ + pb = MIN(b->used, digs - ix); + + for (iy = 0; iy < pb; iy++) { + *_W++ += ((mp_word)tmpx) * ((mp_word)*tmpy++); + } } + } /* setup dest */ @@ -908,11 +1046,12 @@ static int fast_s_mp_mul_digs(mp_int *a, mp_int *b, mp_int *c, int digs) * N^2 + N*c where c is the cost of the shifting. On very small numbers * this is slower but on most cryptographic size numbers it is faster. 
*/ + for (ix = 1; ix < digs; ix++) { - W[ix] = W[ix] + (W[ix-1] >> ((mp_word)DIGIT_BIT)); - c->dp[ix-1] = W[ix-1] & ((mp_word)MP_MASK); + W[ix] += (W[ix-1] >> ((mp_word)DIGIT_BIT)); + c->dp[ix-1] = (mp_digit)(W[ix-1] & ((mp_word)MP_MASK)); } - c->dp[digs-1] = W[digs-1] & ((mp_word)MP_MASK); + c->dp[digs-1] = (mp_digit)(W[digs-1] & ((mp_word)MP_MASK)); /* clear unused */ for (ix = c->used; ix < olduse; ix++) { @@ -924,7 +1063,10 @@ static int fast_s_mp_mul_digs(mp_int *a, mp_int *b, mp_int *c, int digs) return MP_OKAY; } -/* multiplies |a| * |b| and only computes upto digs digits of result */ +/* multiplies |a| * |b| and only computes upto digs digits of result + * HAC pp. 595, Algorithm 14.12 Modified so you can control how many digits of + * output are created. + */ static int s_mp_mul_digs(mp_int *a, mp_int *b, mp_int *c, int digs) { mp_int t; @@ -963,7 +1105,7 @@ static int s_mp_mul_digs(mp_int *a, mp_int *b, mp_int *c, int digs) /* limit ourselves to making digs digits of output */ pb = MIN(b->used, digs - ix); - + /* setup some aliases */ tmpx = a->dp[ix]; tmpt = &(t.dp[ix]); @@ -1001,9 +1143,8 @@ static int s_mp_mul_digs(mp_int *a, mp_int *b, mp_int *c, int digs) */ static int fast_s_mp_mul_high_digs(mp_int *a, mp_int *b, mp_int *c, int digs) { - int oldused, newused, res, pa, pb, ix, iy; - mp_word W[512], *_W; - mp_digit tmpx, *tmpy; + int oldused, newused, res, pa, pb, ix; + mp_word W[512]; REGFUNC("fast_s_mp_mul_high_digs"); VERIFY(a); @@ -1023,14 +1164,29 @@ static int fast_s_mp_mul_high_digs(mp_int *a, mp_int *b, mp_int *c, int digs) pb = b->used; memset(&W[digs], 0, (pa + pb + 1 - digs) * sizeof(mp_word)); for (ix = 0; ix < pa; ix++) { - /* pointer aliases */ - tmpx = a->dp[ix]; - tmpy = b->dp + (digs - ix); - _W = &(W[digs]); + { + register mp_digit tmpx, *tmpy; + register int iy; + register mp_word *_W; + + /* work todo, that is we only calculate digits that are at "digs" or above */ + iy = digs - ix; + + /* copy of word on the left of A[ix] * B[iy] 
*/ + tmpx = a->dp[ix]; + + /* alias for right side */ + tmpy = b->dp + iy; + + /* alias for the columns of output. Offset to be equal to or above the + * smallest digit place requested + */ + _W = &(W[digs]); - /* compute column products for digits above the minimum */ - for (iy = digs - ix; iy < pb; iy++) { - *_W++ += ((mp_word)tmpx) * ((mp_word)*tmpy++); + /* compute column products for digits above the minimum */ + for (; iy < pb; iy++) { + *_W++ += ((mp_word)tmpx) * ((mp_word)*tmpy++); + } } } @@ -1040,10 +1196,10 @@ static int fast_s_mp_mul_high_digs(mp_int *a, mp_int *b, mp_int *c, int digs) /* now convert the array W downto what we need */ for (ix = digs+1; ix < (pa+pb+1); ix++) { - W[ix] = W[ix] + (W[ix-1] >> ((mp_word)DIGIT_BIT)); - c->dp[ix-1] = W[ix-1] & ((mp_word)MP_MASK); + W[ix] += (W[ix-1] >> ((mp_word)DIGIT_BIT)); + c->dp[ix-1] = (mp_digit)(W[ix-1] & ((mp_word)MP_MASK)); } - c->dp[(pa+pb+1)-1] = W[(pa+pb+1)-1] & ((mp_word)MP_MASK); + c->dp[(pa+pb+1)-1] = (mp_digit)(W[(pa+pb+1)-1] & ((mp_word)MP_MASK)); for (ix = c->used; ix < oldused; ix++) { c->dp[ix] = 0; @@ -1085,13 +1241,26 @@ static int s_mp_mul_high_digs(mp_int *a, mp_int *b, mp_int *c, int digs) pa = a->used; pb = b->used; for (ix = 0; ix < pa; ix++) { + /* clear the carry */ u = 0; + + /* left hand side of A[ix] * B[iy] */ tmpx = a->dp[ix]; + + /* alias to the address of where the digits will be stored */ tmpt = &(t.dp[digs]); + + /* alias for where to read the right hand side from */ tmpy = b->dp + (digs - ix); + for (iy = digs - ix; iy < pb; iy++) { + /* calculate the double precision result */ r = ((mp_word)*tmpt) + ((mp_word)tmpx) * ((mp_word)*tmpy++) + ((mp_word)u); + + /* get the lower part */ *tmpt++ = (mp_digit)(r & ((mp_word)MP_MASK)); + + /* carry the carry */ u = (mp_digit)(r >> ((mp_word)DIGIT_BIT)); } *tmpt = u; @@ -1119,9 +1288,8 @@ static int s_mp_mul_high_digs(mp_int *a, mp_int *b, mp_int *c, int digs) */ static int fast_s_mp_sqr(mp_int *a, mp_int *b) { - int olduse, newused, 
res, ix, iy, pa; - mp_word W2[512], W[512], *_W; - mp_digit tmpx, *tmpy; + int olduse, newused, res, ix, pa; + mp_word W2[512], W[512]; REGFUNC("fast_s_mp_sqr"); VERIFY(a); @@ -1144,14 +1312,23 @@ static int fast_s_mp_sqr(mp_int *a, mp_int *b) /* compute the outer product */ W2[ix+ix] += ((mp_word)a->dp[ix]) * ((mp_word)a->dp[ix]); - /* pointer aliasing! */ - tmpx = a->dp[ix]; - tmpy = &(a->dp[ix+1]); - _W = &(W[ix+ix+1]); + { + register mp_digit tmpx, *tmpy; + register mp_word *_W; + register int iy; + + /* copy of left side */ + tmpx = a->dp[ix]; + + /* alias for right side */ + tmpy = a->dp + (ix + 1); + + _W = &(W[ix+ix+1]); - /* inner products */ - for (iy = ix + 1; iy < pa; iy++) { - *_W++ += ((mp_word)tmpx) * ((mp_word)*tmpy++); + /* inner products */ + for (iy = ix + 1; iy < pa; iy++) { + *_W++ += ((mp_word)tmpx) * ((mp_word)*tmpy++); + } } } @@ -1168,9 +1345,9 @@ static int fast_s_mp_sqr(mp_int *a, mp_int *b) W[ix] += W[ix] + W2[ix]; W[ix] = W[ix] + (W[ix-1] >> ((mp_word)DIGIT_BIT)); - b->dp[ix-1] = W[ix-1] & ((mp_word)MP_MASK); + b->dp[ix-1] = (mp_digit)(W[ix-1] & ((mp_word)MP_MASK)); } - b->dp[(pa+pa+1)-1] = W[(pa+pa+1)-1] & ((mp_word)MP_MASK); + b->dp[(pa+pa+1)-1] = (mp_digit)(W[(pa+pa+1)-1] & ((mp_word)MP_MASK)); /* clear high */ for (ix = b->used; ix < olduse; ix++) { @@ -1185,7 +1362,7 @@ static int fast_s_mp_sqr(mp_int *a, mp_int *b) return MP_OKAY; } -/* low level squaring, b = a*a */ +/* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */ static int s_mp_sqr(mp_int *a, mp_int *b) { mp_int t; @@ -1212,15 +1389,34 @@ static int s_mp_sqr(mp_int *a, mp_int *b) t.used = pa + pa + 1; for (ix = 0; ix < pa; ix++) { + /* first calculate the digit at 2*ix */ + /* calculate double precision result */ r = ((mp_word)t.dp[ix+ix]) + ((mp_word)a->dp[ix]) * ((mp_word)a->dp[ix]); + + /* store lower part in result */ t.dp[ix+ix] = (mp_digit)(r & ((mp_word)MP_MASK)); + + /* get the carry */ u = (r >> ((mp_word)DIGIT_BIT)); + + /* left hand side of A[ix] * 
A[iy] */ tmpx = a->dp[ix]; + + /* alias for where to store the results */ tmpt = &(t.dp[ix+ix+1]); for (iy = ix + 1; iy < pa; iy++) { + /* first calculate the product */ r = ((mp_word)tmpx) * ((mp_word)a->dp[iy]); + + /* now calculate the double precision result, note we use + * addition instead of *2 since it's easier to optimize + */ r = ((mp_word)*tmpt) + r + r + ((mp_word)u); + + /* store lower part */ *tmpt++ = (mp_digit)(r & ((mp_word)MP_MASK)); + + /* get carry */ u = (r >> ((mp_word)DIGIT_BIT)); } r = ((mp_word)*tmpt) + u; @@ -1334,11 +1530,11 @@ int mp_sub(mp_int *a, mp_int *b, mp_int *c) return res; } -/* c = |a| * |b| using Karatsuba */ +/* c = |a| * |b| using Karatsuba Multiplication */ static int mp_karatsuba_mul(mp_int *a, mp_int *b, mp_int *c) { mp_int x0, x1, y0, y1, t1, t2, x0y0, x1y1; - int B, err, neg, x; + int B, err, x; REGFUNC("mp_karatsuba_mul"); VERIFY(a); @@ -1396,9 +1592,7 @@ static int mp_karatsuba_mul(mp_int *a, mp_int *b, mp_int *c) /* now calc x1-x0 and y1-y0 */ if (mp_sub(&x1, &x0, &t1) != MP_OKAY) goto X1Y1; /* t1 = x1 - x0 */ if (mp_sub(&y1, &y0, &t2) != MP_OKAY) goto X1Y1; /* t2 = y1 - y0 */ - neg = (t1.sign == t2.sign) ? MP_ZPOS : MP_NEG; if (mp_mul(&t1, &t2, &t1) != MP_OKAY) goto X1Y1; /* t1 = (x1 - x0) * (y1 - y0) */ - t1.sign = neg; /* add x0y0 */ if (mp_add(&x0y0, &x1y1, &t2) != MP_OKAY) goto X1Y1; /* t2 = x0y0 + x1y1 */ @@ -1538,7 +1732,15 @@ int mp_sqr(mp_int *a, mp_int *b) } -/* integer signed division. c*b + d == a [e.g. a/b, c=quotient, d=remainder] */ +/* integer signed division. c*b + d == a [e.g. a/b, c=quotient, d=remainder] + * HAC pp.598 Algorithm 14.20 + * + * Note that the description in HAC is horribly incomplete. For example, + * it doesn't consider the case where digits are removed from 'x' in the inner + * loop. It also doesn't consider the case that y has fewer than three digits, etc. + * + * The overall algorithm is as described in HAC 14.20, but fixed to treat these cases. 
+*/ int mp_div(mp_int *a, mp_int *b, mp_int *c, mp_int *d) { mp_int q, x, y, t1, t2; @@ -1596,7 +1798,7 @@ int mp_div(mp_int *a, mp_int *b, mp_int *c, mp_int *d) neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG; x.sign = y.sign = MP_ZPOS; - /* normalize */ + /* normalize both x and y, ensure that y >= b/2, [b == 2^DIGIT_BIT] */ norm = 0; while ((y.dp[y.used-1] & (((mp_digit)1)<<(DIGIT_BIT-1))) == ((mp_digit)0)) { ++norm; @@ -1927,7 +2129,7 @@ int mp_expt_d(mp_int *a, mp_digit b, mp_int *c) return res; } - if (b & (mp_digit)(1<<(DIGIT_BIT-1))) { + if ((b & (mp_digit)(1<<(DIGIT_BIT-1))) != 0) { if ((res = mp_mul(c, &g, c)) != MP_OKAY) { mp_clear(&g); DECFUNC(); @@ -2106,8 +2308,12 @@ int mp_gcd(mp_int *a, mp_int *b, mp_int *c) k = 0; while ((u.dp[0] & 1) == 0 && (v.dp[0] & 1) == 0) { ++k; - mp_div_2d(&u, 1, &u, NULL); - mp_div_2d(&v, 1, &v, NULL); + if ((res = mp_div_2d(&u, 1, &u, NULL)) != MP_OKAY) { + goto __T; + } + if ((res = mp_div_2d(&v, 1, &v, NULL)) != MP_OKAY) { + goto __T; + } } /* B2. Initialize */ @@ -2125,7 +2331,9 @@ int mp_gcd(mp_int *a, mp_int *b, mp_int *c) do { /* B3 (and B4). Halve t, if even */ while (t.used != 0 && (t.dp[0] & 1) == 0) { - mp_div_2d(&t, 1, &t, NULL); + if ((res = mp_div_2d(&t, 1, &t, NULL)) != MP_OKAY) { + goto __T; + } } /* B5. 
if t>0 then u=t otherwise v=-t */ @@ -2563,7 +2771,9 @@ int mp_reduce_setup(mp_int *a, mp_int *b) return res; } -/* reduces x mod m, assumes 0 < x < m^2, mu is precomputed via mp_reduce_setup */ +/* reduces x mod m, assumes 0 < x < m^2, mu is precomputed via mp_reduce_setup + * From HAC pp.604 Algorithm 14.42 + */ int mp_reduce(mp_int *x, mp_int *m, mp_int *mu) { mp_int q; @@ -2595,10 +2805,14 @@ int mp_reduce(mp_int *x, mp_int *m, mp_int *mu) mp_rshd(&q, um + 1); /* q3 = q2 / b^(k+1) */ /* x = x mod b^(k+1), quick (no division) */ - mp_mod_2d(x, DIGIT_BIT * (um + 1), x); + if ((res = mp_mod_2d(x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) { + goto CLEANUP; + } /* q = q * m mod b^(k+1), quick (no division) */ - s_mp_mul_digs(&q, m, &q, um + 1); + if ((res = s_mp_mul_digs(&q, m, &q, um + 1)) != MP_OKAY) { + goto CLEANUP; + } /* x = x - q */ if((res = mp_sub(x, &q, x)) != MP_OKAY) @@ -2626,6 +2840,11 @@ int mp_reduce(mp_int *x, mp_int *m, mp_int *mu) return res; } +/* computes Y == G^X mod P, HAC pp.616, Algorithm 14.85 + * + * Uses a left-to-right k-ary sliding window to compute the modular exponentiation. + * The value of k changes based on the size of the exponent. 
+ */ int mp_exptmod(mp_int *G, mp_int *X, mp_int *P, mp_int *Y) { mp_int M[64], res, mu; @@ -2648,7 +2867,7 @@ int mp_exptmod(mp_int *G, mp_int *X, mp_int *P, mp_int *Y) /* init G array */ for (x = 0; x < (1<dp[digidx--]; - bitcnt = DIGIT_BIT; + bitcnt = (int)DIGIT_BIT; } /* grab the next msb from the exponent */ @@ -2793,7 +3012,7 @@ int mp_exptmod(mp_int *G, mp_int *X, mp_int *P, mp_int *Y) } bitbuf <<= 1; - if (bitbuf & (1< +#include "bn.h" + +/* fast square root */ +static mp_digit i_sqrt(mp_word x) +{ + mp_word x1, x2; + + x2 = x; + do { + x1 = x2; + x2 = x1 - ((x1 * x1) - x)/(2*x1); + } while (x1 != x2); + + if (x1*x1 > x) { + --x1; + } + + return x1; +} + + +/* generates a prime digit */ +static mp_digit prime_digit() +{ + mp_digit r, x, y, next; + + /* make a DIGIT_BIT-bit random number */ + for (r = x = 0; x < DIGIT_BIT; x++) { + r = (r << 1) | (rand() & 1); + } + + /* now force it odd */ + r |= 1; + + /* force it to be >30 */ + if (r < 30) { + r += 30; + } + + /* get square root, since if 'r' is composite its factors must be < than this */ + y = i_sqrt(r); + next = (y+1)*(y+1); + + do { + r += 2; /* next candidate */ + + /* update sqrt ? 
*/ + if (next <= r) { + ++y; + next = (y+1)*(y+1); + } + + /* loop if divisible by 3,5,7,11,13,17,19,23,29 */ + if ((r % 3) == 0) { x = 0; continue; } + if ((r % 5) == 0) { x = 0; continue; } + if ((r % 7) == 0) { x = 0; continue; } + if ((r % 11) == 0) { x = 0; continue; } + if ((r % 13) == 0) { x = 0; continue; } + if ((r % 17) == 0) { x = 0; continue; } + if ((r % 19) == 0) { x = 0; continue; } + if ((r % 23) == 0) { x = 0; continue; } + if ((r % 29) == 0) { x = 0; continue; } + + /* now check if r is divisible by x + k={1,7,11,13,17,19,23,29} */ + for (x = 30; x <= y; x += 30) { + if ((r % (x+1)) == 0) { x = 0; break; } + if ((r % (x+7)) == 0) { x = 0; break; } + if ((r % (x+11)) == 0) { x = 0; break; } + if ((r % (x+13)) == 0) { x = 0; break; } + if ((r % (x+17)) == 0) { x = 0; break; } + if ((r % (x+19)) == 0) { x = 0; break; } + if ((r % (x+23)) == 0) { x = 0; break; } + if ((r % (x+29)) == 0) { x = 0; break; } + } + } while (x == 0); + + return r; +} + +/* makes a prime of at least k bits */ +int pprime(int k, mp_int *p, mp_int *q) +{ + mp_int a, b, c, n, x, y, z, v; + int res; + + /* single digit ? 
*/ + if (k <= (int)DIGIT_BIT) { + mp_set(p, prime_digit()); + return MP_OKAY; + } + + if ((res = mp_init(&c)) != MP_OKAY) { + return res; + } + + if ((res = mp_init(&v)) != MP_OKAY) { + goto __C; + } + + /* product of first 50 primes */ + if ((res = mp_read_radix(&v, "19078266889580195013601891820992757757219839668357012055907516904309700014933909014729740190", 10)) != MP_OKAY) { + goto __V; + } + + if ((res = mp_init(&a)) != MP_OKAY) { + goto __V; + } + + /* set the prime */ + mp_set(&a, prime_digit()); + + if ((res = mp_init(&b)) != MP_OKAY) { + goto __A; + } + + if ((res = mp_init(&n)) != MP_OKAY) { + goto __B; + } + + if ((res = mp_init(&x)) != MP_OKAY) { + goto __N; + } + + if ((res = mp_init(&y)) != MP_OKAY) { + goto __X; + } + + if ((res = mp_init(&z)) != MP_OKAY) { + goto __Y; + } + + /* now loop making the single digit */ + while (mp_count_bits(&a) < k) { + printf("prime is %4d bits left\r", k - mp_count_bits(&a)); fflush(stdout); + top: + mp_set(&b, prime_digit()); + + /* now compute z = a * b * 2 */ + if ((res = mp_mul(&a, &b, &z)) != MP_OKAY) { /* z = a * b */ + goto __Z; + } + + if ((res = mp_copy(&z, &c)) != MP_OKAY) { /* c = a * b */ + goto __Z; + } + + if ((res = mp_mul_2(&z, &z)) != MP_OKAY) { /* z = 2 * a * b */ + goto __Z; + } + + /* n = z + 1 */ + if ((res = mp_add_d(&z, 1, &n)) != MP_OKAY) { /* n = z + 1 */ + goto __Z; + } + + /* check (n, v) == 1 */ + if ((res = mp_gcd(&n, &v, &y)) != MP_OKAY) { /* y = (n, v) */ + goto __Z; + } + + if (mp_cmp_d(&y, 1) != MP_EQ) goto top; + + /* now try base x=2 */ + mp_set(&x, 2); + + /* compute x^a mod n */ + if ((res = mp_exptmod(&x, &a, &n, &y)) != MP_OKAY) { /* y = x^a mod n */ + goto __Z; + } + + /* if y == 1 loop */ + if (mp_cmp_d(&y, 1) == MP_EQ) goto top; + + /* now x^2a mod n */ + if ((res = mp_sqrmod(&y, &n, &y)) != MP_OKAY) { /* y = x^2a mod n */ + goto __Z; + } + + if (mp_cmp_d(&y, 1) == MP_EQ) goto top; + + /* compute x^b mod n */ + if ((res = mp_exptmod(&x, &b, &n, &y)) != MP_OKAY) { /* y = x^b 
mod n */ + goto __Z; + } + + /* if y == 1 loop */ + if (mp_cmp_d(&y, 1) == MP_EQ) goto top; + + /* now x^2b mod n */ + if ((res = mp_sqrmod(&y, &n, &y)) != MP_OKAY) { /* y = x^2b mod n */ + goto __Z; + } + + if (mp_cmp_d(&y, 1) == MP_EQ) goto top; + + + /* compute x^c mod n == x^ab mod n */ + if ((res = mp_exptmod(&x, &c, &n, &y)) != MP_OKAY) { /* y = x^ab mod n */ + goto __Z; + } + + /* if y == 1 loop */ + if (mp_cmp_d(&y, 1) == MP_EQ) goto top; + + /* now compute (x^c mod n)^2 */ + if ((res = mp_sqrmod(&y, &n, &y)) != MP_OKAY) { /* y = x^2ab mod n */ + goto __Z; + } + + /* y should be 1 */ + if (mp_cmp_d(&y, 1) != MP_EQ) goto top; + +/* +{ + char buf[4096]; + + mp_toradix(&n, buf, 10); + printf("Certificate of primality for:\n%s\n\n", buf); + mp_toradix(&a, buf, 10); + printf("A == \n%s\n\n", buf); + mp_toradix(&b, buf, 10); + printf("B == \n%s\n", buf); + printf("----------------------------------------------------------------\n"); +} +*/ + /* a = n */ + mp_copy(&n, &a); + } + + mp_exch(&n, p); + mp_exch(&b, q); + + res = MP_OKAY; +__Z: mp_clear(&z); +__Y: mp_clear(&y); +__X: mp_clear(&x); +__N: mp_clear(&n); +__B: mp_clear(&b); +__A: mp_clear(&a); +__V: mp_clear(&v); +__C: mp_clear(&c); + return res; +} + + +int main(void) +{ + mp_int p, q; + char buf[4096]; + int k; + clock_t t1; + + srand(time(NULL)); + + printf("Enter # of bits: \n"); + scanf("%d", &k); + + mp_init(&p); + mp_init(&q); + + t1 = clock(); + pprime(k, &p, &q); + t1 = clock() - t1; + + printf("\n\nTook %ld ticks, %d bits\n", t1, mp_count_bits(&p)); + + mp_toradix(&p, buf, 10); + printf("P == %s\n", buf); + mp_toradix(&q, buf, 10); + printf("Q == %s\n", buf); + + return 0; +} + diff --git a/makefile b/makefile index 95c6465..cbb5ac7 100644 --- a/makefile +++ b/makefile @@ -1,14 +1,19 @@ CC = gcc -CFLAGS += -DDEBUG -Wall -W -O3 -fomit-frame-pointer -funroll-loops +CFLAGS += -Wall -W -O3 -fomit-frame-pointer -funroll-loops -VERSION=0.07 +VERSION=0.08 default: test -test: bn.o demo.o +test: bn.o 
demo.o $(CC) bn.o demo.o -o demo cd mtest ; gcc -O3 -fomit-frame-pointer -funroll-loops mtest.c -o mtest.exe -s +# builds the x86 demo +test86: + nasm -f coff timer.asm + $(CC) -DDEBUG -DTIMER_X86 $(CFLAGS) bn.c demo.c timer.o -o demo -s + docdvi: bn.tex latex bn @@ -17,7 +22,7 @@ docs: docdvi rm -f bn.log bn.aux bn.dvi clean: - rm -f *.o *.exe mtest/*.exe bn.log bn.aux bn.dvi *.s + rm -f *.pdf *.o *.exe mtest/*.exe etc/*.exe bn.log bn.aux bn.dvi *.s zipup: clean docs chdir .. ; rm -rf ltm* libtommath-$(VERSION) ; mkdir libtommath-$(VERSION) ; \ diff --git a/mtest/mtest.c b/mtest/mtest.c index 576feb2..de04e2b 100644 --- a/mtest/mtest.c +++ b/mtest/mtest.c @@ -41,7 +41,7 @@ void rand_num(mp_int *a) unsigned char buf[512]; top: - size = 1 + ((fgetc(rng)*fgetc(rng)) % 512); + size = 1 + ((fgetc(rng)*fgetc(rng)) % 32); buf[0] = (fgetc(rng)&1)?1:0; fread(buf+1, 1, size, rng); for (n = 0; n < size; n++) { @@ -57,7 +57,7 @@ void rand_num2(mp_int *a) unsigned char buf[512]; top: - size = 1 + ((fgetc(rng)*fgetc(rng)) % 512); + size = 1 + ((fgetc(rng)*fgetc(rng)) % 32); buf[0] = (fgetc(rng)&1)?1:0; fread(buf+1, 1, size, rng); for (n = 0; n < size; n++) { @@ -80,6 +80,13 @@ int main(void) mp_init(&e); rng = fopen("/dev/urandom", "rb"); + if (rng == NULL) { + rng = fopen("/dev/random", "rb"); + if (rng == NULL) { + fprintf(stderr, "\nWarning: stdin used as random source\n\n"); + rng = stdin; + } + } for (;;) { n = fgetc(rng) % 11; diff --git a/timer.asm b/timer.asm index e8b6383..2393250 100644 --- a/timer.asm +++ b/timer.asm @@ -9,6 +9,12 @@ [section .data] timer dd 0, 0 [section .text] + +[global _gettsc] +_gettsc: + rdtsc + ret + [global _rdtsc] _rdtsc: rdtsc