/*------------------------------------------------------------------------*/
static uint32 lift_root_32(uint32 n, uint32 r, uint32 old_power, uint32 p, uint32 d)
{
	/* One Hensel-lifting step: r is a d-th root of n modulo old_power;
	   the result is r + old_power * t, a d-th root of n modulo
	   old_power * p, where

	       t = ((n - r^d) / old_power) * (d * r^(d-1))^-1   (mod p).

	   NOTE(review): old_power * p is formed in 32-bit arithmetic, so the
	   caller must keep it below 2^32.  d * r^(d-1) must be invertible
	   mod p; what mp_modinv_1 does otherwise is not visible here. */
	uint32 big_mod = old_power * p;
	uint32 diff = mp_modsub_1(n % big_mod, mp_expo_1(r, d, big_mod), big_mod);
	uint32 quot = diff / old_power;
	uint32 slope = mp_modmul_1(d, mp_expo_1(r % p, d - 1, p), p);
	uint32 digit = mp_modmul_1(quot, mp_modinv_1(slope, p), p);
	uint64 lifted = (uint64)r + old_power * digit;

	return lifted;
}
/*------------------------------------------------------------------------*/
static uint32 lift_root_32(uint32 n, uint32 r, uint32 old_power, uint32 p, uint32 d)
{
	/* Given r, a d-th root of n modulo old_power, compute the corresponding
	   root modulo old_power * p by one Hensel-lifting step.
	   NOTE(review): requires old_power * p < 2^32 (computed in 32 bits). */
	const uint32 mod_hi = old_power * p;
	const uint64 base = r;
	uint32 t;

	/* (n - r^d) mod old_power*p is a multiple of old_power when r is a
	   root mod old_power; dividing leaves the error term mod p */
	t = mp_modsub_1(n % mod_hi, mp_expo_1(r, d, mod_hi), mod_hi);
	t /= old_power;

	/* Newton step: divide the error term by f'(r) = d * r^(d-1) mod p */
	t = mp_modmul_1(t, mp_modinv_1(mp_modmul_1(d, mp_expo_1(r % p, d - 1, p), p), p), p);

	return base + old_power * t;
}
/* Return a square root of a modulo p, i.e. some x with x * x == a (mod p).

   Assumes p is an odd prime (p == 2 is also accepted) and that a is a
   quadratic residue mod p.  a itself is never checked, so for a
   non-residue the returned value is not a square root.

   Method by residue class of p:
     p == 3 (mod 4): x = a^((p+1)/4)
     p == 5 (mod 8): x = a^((p+3)/8), multiplied by 2^((p-1)/4) if x^2 != a
     otherwise:      Tonelli-Shanks using the smallest quadratic non-residue
   If no quadratic non-residue below p is found (impossible for an odd
   prime), prints "modsqrt_1 failed" and terminates with exit(-1). */
static inline u_int32_t mp_modsqrt_1(u_int32_t a, u_int32_t p)
{
	u_int32_t a0 = a;
	u_int32_t d0, t, s, m, i, a1, d1;

	/* Mod 2 every residue is its own square root.  Without this case p == 2
	   reaches the Tonelli-Shanks search below, which is empty for p == 2,
	   and the function aborts. */
	if (p == 2)
		return a0 & 1;

	/* Reduce once so every branch, including the a0 == 1 shortcut, sees a
	   canonical residue.  p < 2 is skipped to avoid dividing by zero; such
	   moduli still end at the failure exit as before. */
	if (p > 2 && a0 >= p)
		a0 %= p;

	/* p == 3 or 7 (mod 8) */
	if ((p & 3) == 3)
		return mp_expo_1(a0, (p + 1) / 4, p);

	if ((p & 7) == 5) {
		/* 2 is a non-residue for p == 5 (mod 8), so 2^((p-1)/4) is a
		   square root of -1 mod p; it corrects x when x^2 != a */
		u_int32_t x = mp_expo_1(a0, (p + 3) / 8, p);
		if (mp_modmul_1(x, x, p) == a0)
			return x;
		return mp_modmul_1(x, mp_expo_1(2, (p - 1) / 4, p), p);
	}

	/* Tonelli-Shanks: p == 1 (mod 8), or p is not an odd prime */
	if (a0 == 1)
		return 1;

	/* smallest quadratic non-residue d0 in [2, p) */
	for (d0 = 2; d0 < p; d0++) {
		if (mp_legendre_1(d0, p) == -1)
			break;
	}

	if (d0 < p) {
		/* write p - 1 = 2^s * t with t odd (t >= 2 here since p >= 3) */
		t = p - 1;
		s = 0;
		while (!(t & 1)) {
			s++;
			t >>= 1;
		}

		/* choose the bits of m one at a time so that
		   a^t * (d0^t)^m == 1 (mod p) */
		a1 = mp_expo_1(a0, t, p);
		d1 = mp_expo_1(d0, t, p);
		for (i = 0, m = 0; i < s; i++) {
			u_int32_t ad = mp_expo_1(d1, m, p);
			ad = mp_modmul_1(ad, a1, p);
			ad = mp_expo_1(ad, (u_int32_t)1 << (s - 1 - i), p);
			if (ad == p - 1)
				m += (u_int32_t)1 << i;   /* unsigned shift: no signed overflow */
		}

		/* for a quadratic residue m ends up even, and
		   x = a^((t+1)/2) * (d0^t)^(m/2) satisfies x^2 == a */
		a1 = mp_expo_1(a0, (t + 1) / 2, p);
		d1 = mp_expo_1(d1, m / 2, p);
		return mp_modmul_1(a1, d1, p);
	}

	printf("modsqrt_1 failed\n");
	exit(-1);
	return 0;
}