bool is_one_zero_one_poly(const uint64_t *poly, int coeff_count, int coeff_uint64_count) { #ifdef _DEBUG if (poly == nullptr && coeff_count > 0 && coeff_uint64_count > 0) { throw invalid_argument("poly"); } if (coeff_count < 0) { throw invalid_argument("coeff_count"); } if (coeff_uint64_count < 0) { throw invalid_argument("coeff_uint64_count"); } #endif if (coeff_count == 0 || coeff_uint64_count == 0) { return false; } if (!is_equal_uint(get_poly_coeff(poly, 0, coeff_uint64_count), coeff_uint64_count, 1)) { return false; } if (!is_equal_uint(get_poly_coeff(poly, coeff_count - 1, coeff_uint64_count), coeff_uint64_count, 1)) { return false; } if (coeff_count > 2 && !is_zero_poly(poly + coeff_uint64_count, coeff_count - 2, coeff_uint64_count)) { return false; } return true; }
void exponentiate_poly(const std::uint64_t *poly, int poly_coeff_count, int poly_coeff_uint64_count, const uint64_t *exponent, int exponent_uint64_count, int result_coeff_count, int result_coeff_uint64_count, std::uint64_t *result, MemoryPool &pool) { #ifdef SEAL_DEBUG if (poly == nullptr) { throw invalid_argument("poly"); } if (poly_coeff_count <= 0) { throw invalid_argument("poly_coeff_count"); } if (poly_coeff_count <= 0) { throw invalid_argument("poly_coeff_uint64_count"); } if (exponent == nullptr) { throw invalid_argument("exponent"); } if (exponent_uint64_count <= 0) { throw invalid_argument("exponent_uint64_count"); } if (result == nullptr) { throw invalid_argument("result"); } if (result_coeff_count <= 0) { throw invalid_argument("result_coeff_count"); } if (result_coeff_uint64_count <= 0) { throw invalid_argument("result_coeff_uint64_count"); } #endif // Fast cases if (is_zero_uint(exponent, exponent_uint64_count)) { set_zero_poly(result_coeff_count, result_coeff_uint64_count, result); *result = 1; return; } if (is_equal_uint(exponent, exponent_uint64_count, 1)) { set_poly_poly(poly, poly_coeff_count, poly_coeff_uint64_count, result_coeff_count, result_coeff_uint64_count, result); return; } // Need to make a copy of exponent Pointer exponent_copy(allocate_uint(exponent_uint64_count, pool)); set_uint_uint(exponent, exponent_uint64_count, exponent_copy.get()); // Perform binary exponentiation. Pointer big_alloc(allocate_uint((static_cast<int64_t>(result_coeff_count) + result_coeff_count + result_coeff_count) * result_coeff_uint64_count, pool)); uint64_t *powerptr = big_alloc.get(); uint64_t *productptr = get_poly_coeff(powerptr, result_coeff_count, result_coeff_uint64_count); uint64_t *intermediateptr = get_poly_coeff(productptr, result_coeff_count, result_coeff_uint64_count); set_poly_poly(poly, poly_coeff_count, poly_coeff_uint64_count, result_coeff_count, result_coeff_uint64_count, powerptr); set_zero_poly(result_coeff_count, result_coeff_uint64_count, intermediateptr); *intermediateptr = 1; // Initially: power = operand and intermediate = 1, product is not initialized. while (true) { if ((*exponent_copy.get() % 2) == 1) { multiply_poly_poly(powerptr, result_coeff_count, result_coeff_uint64_count, intermediateptr, result_coeff_count, result_coeff_uint64_count, result_coeff_count, result_coeff_uint64_count, productptr, pool); swap(productptr, intermediateptr); } right_shift_uint(exponent_copy.get(), 1, exponent_uint64_count, exponent_copy.get()); if (is_zero_uint(exponent_copy.get(), exponent_uint64_count)) { break; } multiply_poly_poly(powerptr, result_coeff_count, result_coeff_uint64_count, powerptr, result_coeff_count, result_coeff_uint64_count, result_coeff_count, result_coeff_uint64_count, productptr, pool); swap(productptr, powerptr); } set_poly_poly(intermediateptr, result_coeff_count, result_coeff_uint64_count, result); }