//we don't have to bother with making sure that gcd(R,M) == 1 since M is odd. uberzahl modexp_mm(mm_t & mm, uberzahl base, uberzahl exp, uberzahl M){ if(!mm.initialized){ mm.R = next_power(M); mm.Rbits = mm.R.bitLength(); mm.Mprime = (mm.R-M.inverse(mm.R)); uberzahl z("1"); uberzahl t("2"); mm.Rsq = modexp(mm.R,t,M); //mm.z_init = mm.R % M; mm.z_init = montgomery_reduction(mm.Rsq, M, mm.Mprime, mm.Rbits, mm.R); mm.initialized = true; } //convert into Montgomery space uberzahl z = mm.z_init; //According to Piazza post we don't even need to calculate the residues with mod if(base * mm.Rsq < mm.R*M) base = montgomery_reduction(base * mm.Rsq, M, mm.Mprime, mm.Rbits, mm.R); else base = base * mm.R % M; mediumType i = exp.bitLength() - 1; while(i >= 0) { z = montgomery_reduction(z * z, M, mm.Mprime, mm.Rbits, mm.R); if(exp.bit(i) == 1){ z = montgomery_reduction(z * base , M, mm.Mprime, mm.Rbits, mm.R); } if(i == 0) break; i -= 1; } return montgomery_reduction(z, M, mm.Mprime, mm.Rbits, mm.R); }
// Reference (non-Montgomery) left-to-right square-and-multiply.
//
// Computes a^c mod (p*q).
// When q > 1, the elapsed steady-clock time in seconds is written to stderr,
// prefixed with "\tSqMultOrig time: ".
//
// @param c  Exponent.
// @param a  Base.
// @param p  First factor of the modulus.
// @param q  Second factor of the modulus; also gates the timing report.
// @return a^c mod (p*q).
uberzahl originalModExp(uberzahl c, uberzahl a, uberzahl p, uberzahl q){
  const auto t0 = chrono::steady_clock::now();

  uberzahl acc = 1;
  uberzahl modulus = p * q;

  // Walk the exponent's bits from the most significant (pos - 1 == width - 1)
  // down to bit 0.
  const unsigned int width = c.bitLength();
  for (unsigned int pos = width; pos > 0; --pos) {
    acc = (acc * acc) % modulus;
    if (c.bit(pos - 1) == 1)
      acc = (acc * a) % modulus;
  }

  if (q > 1) {
    const chrono::duration<double> secs =
        chrono::duration_cast<chrono::duration<double>>(chrono::steady_clock::now() - t0);
    cerr << "\tSqMultOrig time: " << secs.count() << "\n";
  }
  return acc;
}
uberzahl multiply (uberzahl x, uberzahl y) { uberzahl z, v = y, r = "299076299051606071403356588563077529600"; for (int i = 127; i >= 0; --i) { // Set z block if (x.bit(i) == 0) { z = z; } else { z = z ^ v; } // Set v block if (v.bit(0) == 0) { v = v >> 1; } else {
// Computes base^exp mod n with left-to-right binary square-and-multiply.
//
// @param base  Value to exponentiate; it is reduced as part of each multiply.
// @param exp   Exponent; exp == 0 yields 1 mod n (0 when n == 1).
// @param n     Modulus; must be nonzero, because every step reduces with % n.
// @return base^exp mod n.
uberzahl modexp(uberzahl base, uberzahl exp, uberzahl n){
  // Start at 1 reduced mod n. For exp > 0 this equals what the first squaring
  // produces anyway. For an exponent with no bits to scan, it makes the result
  // 1 mod n rather than an unreduced 1.
  uberzahl z(static_cast<largeType>(1));
  z = z % n;

  // Scan bits from most to least significant.
  // Counting down with `k-- > 0` never wraps around, even if exp.bitLength()
  // is 0. Starting at bitLength() - 1 on the (presumably unsigned) mediumType
  // would wrap in that case.
  for (mediumType k = exp.bitLength(); k-- > 0; ) {
    z = (z * z) % n;
    if (exp.bit(k) == 1)
      z = (z * base) % n;
  }
  return z;
}