示例#1
0
    // Sizes destination and delegates to util::poly_eval_poly to evaluate
    // poly_to_evaluate at the polynomial poly_to_evaluate_at.
    //
    // Special cases handled here:
    //   - poly_to_evaluate is zero: destination is set to zero.
    //   - poly_to_evaluate_at is zero: destination becomes a single coefficient
    //     copied from the first coefficient of poly_to_evaluate.
    //
    // Throws invalid_argument if pool is uninitialized.
    void poly_eval_poly(const BigPoly &poly_to_evaluate, const BigPoly &poly_to_evaluate_at, 
        BigPoly &destination, const MemoryPoolHandle &pool)
    {
        if (!pool)
        {
            throw invalid_argument("pool is uninitialized");
        }

        // Coefficient widths of both inputs, in bits and in 64-bit words.
        const auto outer_bit_count = poly_to_evaluate.coeff_bit_count();
        const auto inner_bit_count = poly_to_evaluate_at.coeff_bit_count();
        const int outer_uint64_count = divide_round_up(outer_bit_count, bits_per_uint64);
        const int inner_uint64_count = divide_round_up(inner_bit_count, bits_per_uint64);

        if (poly_to_evaluate.is_zero())
        {
            destination.set_zero();
            return;
        }

        if (poly_to_evaluate_at.is_zero())
        {
            // Only the first coefficient of poly_to_evaluate is copied.
            destination.resize(1, outer_bit_count);
            set_uint_uint(poly_to_evaluate.data(), outer_uint64_count, destination.data());
            return;
        }

        // Coefficient count comes from the significant coefficient counts (degree product + 1);
        // the bit width is sized from the full coefficient count of poly_to_evaluate.
        const int composed_coeff_count =
            (poly_to_evaluate.significant_coeff_count() - 1) * (poly_to_evaluate_at.significant_coeff_count() - 1) + 1;
        const int composed_bit_count = outer_bit_count + (poly_to_evaluate.coeff_count() - 1) * inner_bit_count;
        const int composed_uint64_count = divide_round_up(composed_bit_count, bits_per_uint64);
        destination.resize(composed_coeff_count, composed_bit_count);

        util::poly_eval_poly(
            poly_to_evaluate.data(), poly_to_evaluate.coeff_count(), outer_uint64_count,
            poly_to_evaluate_at.data(), poly_to_evaluate_at.coeff_count(), inner_uint64_count,
            composed_coeff_count, composed_uint64_count, destination.data(), pool);
    }
    // Builds the ChooserPoly estimate for operand raised to the power exponent.
    //
    // The returned estimate has coefficient count exponent * (max_coeff_count_ - 1) + 1
    // and a max absolute value of max_abs_value_^exponent multiplied by an approximate
    // growth factor.
    //
    // Throws invalid_argument if operand is not correctly initialized, or if both
    // exponent and operand.max_abs_value_ are zero.
    ChooserPoly ChooserEvaluator::exponentiate(const ChooserPoly &operand, uint64_t exponent)
    {
        if (operand.max_coeff_count_ <= 0 || operand.comp_ == nullptr)
        {
            throw invalid_argument("operand is not correctly initialized");
        }
        if (exponent == 0 && operand.max_abs_value_.is_zero())
        {
            throw invalid_argument("undefined operation");
        }
        if (exponent == 0)
        {
            return ChooserPoly(1, 1, new ExponentiateComputation(*operand.comp_, exponent));
        }
        if (operand.max_abs_value_.is_zero())
        {
            return ChooserPoly(1, 0, new ExponentiateComputation(*operand.comp_, exponent));
        }

        // There is no known closed formula for the growth factor, but we use the asymptotic approximation
        // k^n * sqrt[6/((k-1)*(k+1)*Pi*n)], where k = max_coeff_count_, n = exponent.
        //
        // For k == 1 the factor (k-1) is zero, so the approximation divides by zero and yields
        // infinity; converting that to uint64_t is undefined behavior. With a single coefficient
        // no cross terms are summed, so the growth factor is exactly 1.
        uint64_t growth_factor = 1;
        if (operand.max_coeff_count_ > 1)
        {
            growth_factor = static_cast<uint64_t>(pow(operand.max_coeff_count_, exponent) * sqrt(6 / ((operand.max_coeff_count_ - 1) * (operand.max_coeff_count_ + 1) * 3.1415 * exponent)));
        }

        // NOTE(review): exponent is narrowed to int here and in the returned coefficient count;
        // very large exponents would truncate - confirm callers keep exponent small.
        int result_bit_count = static_cast<int>(exponent) * operand.max_abs_value_.significant_bit_count() + get_significant_bit_count(growth_factor) + 1;
        int result_uint64_count = divide_round_up(result_bit_count, bits_per_uint64);

        Pointer result_max_abs_value(allocate_uint(result_uint64_count, pool_));

        // result = max_abs_value_^exponent
        util::exponentiate_uint(operand.max_abs_value_.pointer(), operand.max_abs_value_.uint64_count(), &exponent, 1, result_uint64_count, result_max_abs_value.get(), pool_);

        // result *= growth_factor; the multiplication reads from a duplicate so that
        // input and output buffers are distinct.
        ConstPointer temp_pointer(duplicate_uint_if_needed(result_max_abs_value.get(), result_uint64_count, result_uint64_count, true, pool_));
        multiply_uint_uint(&growth_factor, 1, temp_pointer.get(), result_uint64_count, result_uint64_count, result_max_abs_value.get());

        return ChooserPoly(static_cast<int>(exponent) * (operand.max_coeff_count_ - 1) + 1, BigUInt(result_bit_count, result_max_abs_value.get()), new ExponentiateComputation(*operand.comp_, exponent));
    }
    // Extends secret_key_array_ so that it holds at least max_power entries. Each new
    // entry is the previous entry multiplied by secret_key_ (via
    // multiply_poly_poly_polymod_coeffmod with polymod_ and mod_).
    //
    // Does nothing when the array already has max_power or more entries.
    // Throws invalid_argument if max_power is less than 1.
    // NOTE(review): reads entry (old size - 1), so this presumably relies on the
    // array already holding at least one entry - confirm against the constructor.
    void HkeyGen::compute_secret_key_array(int max_power)
    {
        if (max_power < 1)
        {
            throw invalid_argument("max_power cannot be less than 1");
        }

        const int existing_count = secret_key_array_.size();
        const int target_count = (existing_count < max_power) ? max_power : existing_count;
        if (target_count == existing_count)
        {
            // Every requested power is already present.
            return;
        }

        const int coeff_count = poly_modulus_.coeff_count();
        const int coeff_bit_count = coeff_modulus_.bit_count();
        const int coeff_uint64_count = divide_round_up(coeff_bit_count, bits_per_uint64);

        // Make room for the additional powers.
        secret_key_array_.resize(target_count, coeff_count, coeff_bit_count);

        MemoryPool &pool = *MemoryPool::default_pool();

        // Consecutive entries are addressed as one buffer, stride words apart.
        const int stride = coeff_count * coeff_uint64_count;
        uint64_t *source = secret_key_array_.pointer(existing_count - 1);
        for (int remaining = target_count - existing_count; remaining > 0; --remaining)
        {
            uint64_t *target = source + stride;
            multiply_poly_poly_polymod_coeffmod(source, secret_key_.pointer(), polymod_, mod_, target, pool);
            source = target;
        }
    }
示例#4
0
 // Fills poly with noise drawn from a ClippedNormalDistribution parameterized by
 // noise_standard_deviation_ and noise_max_deviation_. Negative samples are stored
 // as coeff_modulus_ minus their magnitude. The last coefficient is always zero.
 // If either deviation parameter is zero, the whole polynomial is zeroed.
 void KeyGenerator::set_poly_coeffs_normal(uint64_t *poly) const
 {
     const int coeff_count = poly_modulus_.coeff_count();
     const int coeff_uint64_count = divide_round_up(poly_modulus_.coeff_bit_count(), bits_per_uint64);

     if (noise_standard_deviation_ == 0 || noise_max_deviation_ == 0)
     {
         set_zero_poly(coeff_count, coeff_uint64_count, poly);
         return;
     }

     RandomToStandardAdapter engine(random_generator_.get());
     ClippedNormalDistribution random(0, noise_standard_deviation_, noise_max_deviation_);

     uint64_t *coeff = poly;
     for (int remaining = coeff_count - 1; remaining > 0; --remaining, coeff += coeff_uint64_count)
     {
         const int64_t sample = static_cast<int64_t>(random(engine));
         if (sample == 0)
         {
             set_zero_uint(coeff_uint64_count, coeff);
             continue;
         }

         // Store the magnitude first; a negative sample is then reflected around the modulus.
         const bool is_negative = sample < 0;
         const int64_t magnitude = is_negative ? -sample : sample;
         set_uint(static_cast<uint64_t>(magnitude), coeff_uint64_count, coeff);
         if (is_negative)
         {
             sub_uint_uint(coeff_modulus_.pointer(), coeff, coeff_uint64_count, coeff);
         }
     }

     // Last coefficient is left at zero.
     set_zero_uint(coeff_uint64_count, coeff);
 }
示例#5
0
 // Fills poly with coefficients drawn uniformly from {-1, 0, 1}, where -1 is stored
 // as coeff_modulus_minus_one_. The last coefficient is always zero.
 void KeyGenerator::set_poly_coeffs_zero_one_negone(uint64_t *poly) const
 {
     const int coeff_count = poly_modulus_.coeff_count();
     const int coeff_uint64_count = divide_round_up(poly_modulus_.coeff_bit_count(), bits_per_uint64);

     RandomToStandardAdapter engine(random_generator_.get());
     uniform_int_distribution<int> random(-1, 1);

     uint64_t *coeff = poly;
     for (int remaining = coeff_count - 1; remaining > 0; --remaining, coeff += coeff_uint64_count)
     {
         switch (random(engine))
         {
         case 1:
             set_uint(1, coeff_uint64_count, coeff);
             break;
         case -1:
             set_uint_uint(coeff_modulus_minus_one_.pointer(), coeff_uint64_count, coeff);
             break;
         default:
             set_zero_uint(coeff_uint64_count, coeff);
             break;
         }
     }

     // Last coefficient is left at zero.
     set_zero_uint(coeff_uint64_count, coeff);
 }
    void Evaluator::relinearize(const uint64_t *encrypted, uint64_t *destination)
    {
        // Extract encryption parameters.
        int coeff_count = poly_modulus_.coeff_count();
        int coeff_bit_count = poly_modulus_.coeff_bit_count();
        int coeff_uint64_count = divide_round_up(coeff_bit_count, bits_per_uint64);

        // Clear destatintion.
        set_zero_poly(coeff_count, coeff_uint64_count, destination);

        // Create polynomial to store decomposed polynomial (one at a time).
        Pointer decomp_poly(allocate_poly(coeff_count, coeff_uint64_count, pool_));
        Pointer decomp_eval_poly(allocate_poly(coeff_count, coeff_uint64_count, pool_));
        int shift = 0;
        for (int decomp_index = 0; decomp_index < evaluation_keys_.count(); ++decomp_index)
        {
            // Isolate decomposition_bit_count_ bits for each coefficient.
            for (int coeff_index = 0; coeff_index < coeff_count; ++coeff_index)
            {
                const uint64_t *productmoded_coeff = get_poly_coeff(encrypted, coeff_index, coeff_uint64_count);
                uint64_t *decomp_coeff = get_poly_coeff(decomp_poly.get(), coeff_index, coeff_uint64_count);
                right_shift_uint(productmoded_coeff, shift, coeff_uint64_count, decomp_coeff);
                filter_highbits_uint(decomp_coeff, coeff_uint64_count, decomposition_bit_count_);
            }

            // Multiply decomposed poly by evaluation key and accumulate to result.
            const BigPoly &evaluation_key = evaluation_keys_[decomp_index];
            multiply_poly_poly_polymod_coeffmod(decomp_poly.get(), evaluation_key.pointer(), polymod_, mod_, decomp_eval_poly.get(), pool_);
            add_poly_poly_coeffmod(decomp_eval_poly.get(), destination, coeff_count, coeff_modulus_.pointer(), coeff_uint64_count, destination);

            // Increase shift by decomposition_bit_count_ for next iteration.
            shift += decomposition_bit_count_;
        }
    }
    // Negates every coefficient of encrypted and stores the result in destination,
    // resizing destination to the encryption parameters when needed. In TEST_MODE the
    // negation is taken modulo plain_modulus_, otherwise modulo coeff_modulus_.
    //
    // Throws invalid_argument if encrypted does not match the encryption parameters.
    void Evaluator::negate(const BigPoly &encrypted, BigPoly &destination)
    {
        // Extract encryption parameters.
        const int coeff_count = poly_modulus_.coeff_count();
        const int coeff_bit_count = poly_modulus_.coeff_bit_count();
        const int coeff_uint64_count = divide_round_up(coeff_bit_count, bits_per_uint64);

        // Verify parameters.
        const bool encrypted_shape_ok =
            encrypted.coeff_count() == coeff_count && encrypted.coeff_bit_count() == coeff_bit_count;
        if (!encrypted_shape_ok)
        {
            throw invalid_argument("encrypted is not valid for encryption parameters");
        }
#ifdef _DEBUG
        if (encrypted.significant_coeff_count() == coeff_count || !are_poly_coefficients_less_than(encrypted, coeff_modulus_))
        {
            throw invalid_argument("encrypted is not valid for encryption parameters");
        }
#endif
        const bool destination_shape_ok =
            destination.coeff_count() == coeff_count && destination.coeff_bit_count() == coeff_bit_count;
        if (!destination_shape_ok)
        {
            destination.resize(coeff_count, coeff_bit_count);
        }

        // The only difference between test mode and normal mode is the modulus used.
        const uint64_t *active_modulus = (mode_ == TEST_MODE) ? plain_modulus_.pointer() : coeff_modulus_.pointer();
        negate_poly_coeffmod(encrypted.pointer(), coeff_count, active_modulus, coeff_uint64_count, destination.pointer());
    }
    // Builds the ChooserPoly estimate for operand multiplied by a plaintext described by
    // plain_max_coeff_count and plain_max_abs_value.
    //
    // The estimated max absolute value is
    //   operand.max_abs_value_ * plain_max_abs_value * min(operand.max_coeff_count_, plain_max_coeff_count)
    // and the estimated coefficient count is the sum of both counts minus one. If either
    // max absolute value is zero, a ChooserPoly(1, 0, ...) is returned.
    //
    // Throws invalid_argument if operand is not correctly initialized or
    // plain_max_coeff_count is not positive.
    ChooserPoly ChooserEvaluator::multiply_plain(const ChooserPoly &operand, int plain_max_coeff_count, const BigUInt &plain_max_abs_value)
    {
        if (operand.max_coeff_count_ <= 0 || operand.comp_ == nullptr)
        {
            throw invalid_argument("operand is not correctly initialized");
        }
        if (plain_max_coeff_count <= 0)
        {
            throw invalid_argument("plain_max_coeff_count must be positive");
        }

        // A zero factor on either side makes the product trivially zero.
        if (plain_max_abs_value.is_zero() || operand.max_abs_value_.is_zero())
        {
            return ChooserPoly(1, 0, new MultiplyPlainComputation(*operand.comp_, plain_max_coeff_count, plain_max_abs_value));
        }

        uint64_t growth_factor = min(operand.max_coeff_count_, plain_max_coeff_count);
        const int bound_bit_count = operand.max_abs_value_.significant_bit_count()
            + plain_max_abs_value.significant_bit_count()
            + get_significant_bit_count(growth_factor) + 1;
        const int bound_uint64_count = divide_round_up(bound_bit_count, bits_per_uint64);

        Pointer bound(allocate_zero_uint(bound_uint64_count, pool_));
        ConstPointer widened_operand_bound(duplicate_uint_if_needed(operand.max_abs_value_.pointer(), operand.max_abs_value_.uint64_count(), bound_uint64_count, false, pool_));

        // bound = growth_factor * plain_max_abs_value
        multiply_uint_uint(&growth_factor, 1, plain_max_abs_value.pointer(), plain_max_abs_value.uint64_count(), bound_uint64_count, bound.get());

        // bound *= operand.max_abs_value_; the partial product is read from a duplicate
        // so that input and output buffers are distinct.
        ConstPointer partial_product(duplicate_uint_if_needed(bound.get(), bound_uint64_count, bound_uint64_count, true, pool_));
        multiply_uint_uint(widened_operand_bound.get(), bound_uint64_count, partial_product.get(), bound_uint64_count, bound_uint64_count, bound.get());

        return ChooserPoly(operand.max_coeff_count_ + plain_max_coeff_count - 1, BigUInt(bound_bit_count, bound.get()), new MultiplyPlainComputation(*operand.comp_, plain_max_coeff_count, plain_max_abs_value));
    }
示例#9
0
    // Convenience wrapper around util::poly_infty_norm_coeffmod for BigPoly/BigUInt inputs.
    // Returns a default-constructed BigUInt when poly is zero; otherwise the result is
    // sized to the significant bit count of modulus.
    //
    // Throws invalid_argument if modulus is zero or pool is uninitialized.
    BigUInt poly_infty_norm_coeffmod(const BigPoly &poly, const BigUInt &modulus, const MemoryPoolHandle &pool)
    {
        if (modulus.is_zero())
        {
            throw invalid_argument("modulus cannot be zero");
        }
        if (!pool)
        {
            throw invalid_argument("pool is uninitialized");
        }
        if (poly.is_zero())
        {
            return BigUInt();
        }

        const int words_per_coeff = divide_round_up(poly.coeff_bit_count(), bits_per_uint64);

        Modulus mod(modulus.data(), modulus.uint64_count(), pool);
        BigUInt norm(modulus.significant_bit_count());
        util::poly_infty_norm_coeffmod(poly.data(), poly.coeff_count(), words_per_coeff, mod, norm.data(), pool);
        return norm;
    }
    void Evaluator::preencrypt(const uint64_t *plain, int plain_coeff_count, int plain_coeff_uint64_count, uint64_t *destination)
    {
        // Extract encryption parameters.
        int coeff_count = poly_modulus_.coeff_count();
        int coeff_bit_count = poly_modulus_.coeff_bit_count();
        int coeff_uint64_count = divide_round_up(coeff_bit_count, bits_per_uint64);

        // Only care about coefficients up till coeff_count.
        if (plain_coeff_count > coeff_count)
        {
            plain_coeff_count = coeff_count;
        }

        // Multiply plain by scalar coeff_div_plaintext and reposition if in upper-half.
        if (plain == destination)
        {
            // If plain and destination are same poly, then need another storage for multiply output.
            Pointer temp(allocate_uint(coeff_uint64_count, pool_));
            for (int i = 0; i < plain_coeff_count; ++i)
            {
                multiply_uint_uint(plain, plain_coeff_uint64_count, coeff_div_plain_modulus_.pointer(), coeff_uint64_count, coeff_uint64_count, temp.get());
                bool is_upper_half = is_greater_than_or_equal_uint_uint(temp.get(), upper_half_threshold_.pointer(), coeff_uint64_count);
                if (is_upper_half)
                {
                    add_uint_uint(temp.get(), upper_half_increment_.pointer(), coeff_uint64_count, destination);
                }
                else
                {
                    set_uint_uint(temp.get(), coeff_uint64_count, destination);
                }
                plain += plain_coeff_uint64_count;
                destination += coeff_uint64_count;
            }
        }
        else
        {
            for (int i = 0; i < plain_coeff_count; ++i)
            {
                multiply_uint_uint(plain, plain_coeff_uint64_count, coeff_div_plain_modulus_.pointer(), coeff_uint64_count, coeff_uint64_count, destination);
                bool is_upper_half = is_greater_than_or_equal_uint_uint(destination, upper_half_threshold_.pointer(), coeff_uint64_count);
                if (is_upper_half)
                {
                    add_uint_uint(destination, upper_half_increment_.pointer(), coeff_uint64_count, destination);
                }
                plain += plain_coeff_uint64_count;
                destination += coeff_uint64_count;
            }
        }

        // Zero any remaining coefficients.
        for (int i = plain_coeff_count; i < coeff_count; ++i)
        {
            set_zero_uint(coeff_uint64_count, destination);
            destination += coeff_uint64_count;
        }
    }
    void Evaluator::sub_plain(const BigPoly &encrypted1, const BigPoly &plain2, BigPoly &destination)
    {
        // Extract encryption parameters.
        int coeff_count = poly_modulus_.coeff_count();
        int coeff_bit_count = poly_modulus_.coeff_bit_count();
        int coeff_uint64_count = divide_round_up(coeff_bit_count, bits_per_uint64);

        // Verify parameters.
        if (encrypted1.coeff_count() != coeff_count || encrypted1.coeff_bit_count() != coeff_bit_count)
        {
            throw invalid_argument("encrypted1 is not valid for encryption parameters");
        }
#ifdef _DEBUG
        if (encrypted1.significant_coeff_count() == coeff_count || !are_poly_coefficients_less_than(encrypted1, coeff_modulus_))
        {
            throw invalid_argument("encrypted1 is not valid for encryption parameters");
        }
        if (plain2.significant_coeff_count() >= coeff_count || !are_poly_coefficients_less_than(plain2, plain_modulus_))
        {
            throw invalid_argument("plain2 is too large to be represented by encryption parameters");
        }
#endif
        if (destination.coeff_count() != coeff_count || destination.coeff_bit_count() != coeff_bit_count)
        {
            destination.resize(coeff_count, coeff_bit_count);
        }

        int plain2_coeff_uint64_count = divide_round_up(plain2.coeff_bit_count(), bits_per_uint64);
        if (mode_ == TEST_MODE)
        {
            // Handle test-mode case.
            set_poly_poly(plain2.pointer(), plain2.coeff_count(), plain2_coeff_uint64_count, coeff_count, coeff_uint64_count, destination.pointer());
            modulo_poly_coeffs(destination.pointer(), coeff_count, mod_, pool_);
            sub_poly_poly_coeffmod(encrypted1.pointer(), destination.pointer(), coeff_count, plain_modulus_.pointer(), coeff_uint64_count, destination.pointer());
            return;
        }

        // Multiply plain by scalar coeff_div_plaintext and reposition if in upper-half.
        preencrypt(plain2.pointer(), plain2.coeff_count(), plain2_coeff_uint64_count, destination.pointer());

        // Subtract encrypted polynomial and encrypted-version of plain2.
        sub_poly_poly_coeffmod(encrypted1.pointer(), destination.pointer(), coeff_count, coeff_modulus_.pointer(), coeff_uint64_count, destination.pointer());
    }
示例#12
0
/**
   Produce a rendering of the completions.

   Layouts are tried starting at PAGER_MAX_COLS columns and going down one column at
   a time; the first layout that completion_try_print accepts is kept. A column count
   is skipped when the same number of rows can be reached with fewer columns.
*/
page_rendering_t pager_t::render() const
{
    page_rendering_t rendering;
    rendering.term_width = this->available_term_width;
    rendering.term_height = this->available_term_height;
    rendering.search_field_shown = this->search_field_shown;
    rendering.search_field_line = this->search_field_line;

    for (int cols = PAGER_MAX_COLS; cols > 0; cols--)
    {
        /* Start every attempt from an empty rendering */
        rendering.screen_data.resize(0);

        /* Rows needed with 'cols' columns, and the fewest columns that still fit in that many rows.
           Example: 19 completions fit into 6 columns x 4 rows (last row has 1 entry) but also into
           5 columns x 4 rows (last row has 4 entries); the narrower layout is preferred. */
        const size_t rows_needed = divide_round_up(completion_infos.size(), cols);
        const size_t cols_needed = divide_round_up(completion_infos.size(), rows_needed);

        assert(cols_needed <= cols);

        /* A later iteration will handle the narrower layout, so there is nothing to try here */
        const bool narrower_layout_exists = (cols > 1 && cols_needed < cols);
        if (!narrower_layout_exists)
        {
            rendering.cols = static_cast<size_t>(cols);
            rendering.rows = rows_needed;
            rendering.selected_completion_idx = this->visual_selected_completion_index(rendering.rows, rendering.cols);

            if (completion_try_print(cols, prefix, completion_infos, &rendering, suggested_row_start))
            {
                break;
            }
        }
    }
    return rendering;
}
示例#13
0
		/*
			Returns how deep the node hierarchy is for a tree with *count* values.
			Counts both leaf-nodes and inodes.
			0: empty
			1: one leaf node.
			2: one inode with 1-4 leaf nodes.
			3: two levels of inodes plus leaf nodes.

			Each pass divides the number of nodes on the current level by
			BRANCHING_FACTOR (rounding up) and adds one level until at most one
			node remains.
		*/
		inline int count_to_depth(size_t count){
			int levels_above = 0;
			size_t nodes_on_level = count;

			for(;;){
				const auto parent_count = divide_round_up(nodes_on_level, BRANCHING_FACTOR);

				if(parent_count == 0){
					return levels_above;
				}
				if(parent_count == 1){
					return levels_above + 1;
				}

				++levels_above;
				nodes_on_level = parent_count;
			}
		}
    void Evaluator::exponentiate_norelin(const BigPoly &encrypted, int exponent, BigPoly &destination)
    {
        // Extract encryption parameters.
        int coeff_count = poly_modulus_.coeff_count();
        int coeff_bit_count = poly_modulus_.coeff_bit_count();
        int coeff_uint64_count = divide_round_up(coeff_bit_count, bits_per_uint64);

        // Verify parameters.
        if (encrypted.coeff_count() != coeff_count || encrypted.coeff_bit_count() != coeff_bit_count)
        {
            throw invalid_argument("encrypted is not valid for encryption parameters");
        }
#ifdef _DEBUG
        if (encrypted.significant_coeff_count() == coeff_count || !are_poly_coefficients_less_than(encrypted, coeff_modulus_))
        {
            throw invalid_argument("encrypted is not valid for encryption parameters");
        }
#endif
        if (exponent < 0)
        {
            throw invalid_argument("exponent must be non-negative");
        }

        if (exponent == 0)
        {
            if (destination.coeff_count() != coeff_count || destination.coeff_bit_count() != coeff_bit_count)
            {
                destination.resize(coeff_count, coeff_bit_count);
            }
            set_uint_uint(coeff_div_plain_modulus_.pointer(), coeff_uint64_count, destination.pointer());
            return;
        }
        if (exponent == 1)
        {
            encrypted.duplicate_to(destination);
            return;
        }

        vector<BigPoly> exp_vector(exponent, encrypted);
        multiply_norelin_many(exp_vector, destination);

        // Binary exponentiation
        /*
        if (exponent % 2 == 0)
        {
            exponentiate_norelin(multiply_norelin(encrypted, encrypted), exponent >> 1, destination);
            return;
        }
        multiply_norelin(exponentiate_norelin(multiply_norelin(encrypted, encrypted), (exponent - 1) >> 1), encrypted, destination);
        */
    }
示例#15
0
    // Convenience wrapper around util::poly_infty_norm for a BigPoly input.
    // Returns a default-constructed BigUInt when poly is zero; otherwise the result
    // has the same bit width as the coefficients of poly.
    BigUInt poly_infty_norm(const BigPoly &poly)
    {
        if (poly.is_zero())
        {
            return BigUInt();
        }

        const int bits_per_coeff = poly.coeff_bit_count();
        const int words_per_coeff = divide_round_up(bits_per_coeff, bits_per_uint64);

        BigUInt norm(bits_per_coeff);
        util::poly_infty_norm(poly.data(), poly.coeff_count(), words_per_coeff, norm.data());
        return norm;
    }
    void Evaluator::multiply(const BigPoly &encrypted1, const BigPoly &encrypted2, BigPoly &destination)
    {
        // Extract encryption parameters.
        int coeff_count = poly_modulus_.coeff_count();
        int coeff_bit_count = poly_modulus_.coeff_bit_count();
        int coeff_uint64_count = divide_round_up(coeff_bit_count, bits_per_uint64);

        // Verify parameters.
        if (encrypted1.coeff_count() != coeff_count || encrypted1.coeff_bit_count() != coeff_bit_count)
        {
            throw invalid_argument("encrypted1 is not valid for encryption parameters");
        }
        if (encrypted2.coeff_count() != coeff_count || encrypted2.coeff_bit_count() != coeff_bit_count)
        {
            throw invalid_argument("encrypted2 is not valid for encryption parameters");
        }
#ifdef _DEBUG
        if (encrypted1.significant_coeff_count() == coeff_count || !are_poly_coefficients_less_than(encrypted1, coeff_modulus_))
        {
            throw invalid_argument("encrypted1 is not valid for encryption parameters");
        }
        if (encrypted2.significant_coeff_count() == coeff_count || !are_poly_coefficients_less_than(encrypted2, coeff_modulus_))
        {
            throw invalid_argument("encrypted2 is not valid for encryption parameters");
        }
#endif
        if (destination.coeff_count() != coeff_count || destination.coeff_bit_count() != coeff_bit_count)
        {
            destination.resize(coeff_count, coeff_bit_count);
        }

        // Handle test-mode case.
        if (mode_ == TEST_MODE)
        {
            // Get pointer to inputs (duplicated if needed).
            ConstPointer encrypted1ptr = duplicate_poly_if_needed(encrypted1, encrypted1.pointer() == destination.pointer(), pool_);
            ConstPointer encrypted2ptr = duplicate_poly_if_needed(encrypted2, encrypted2.pointer() == destination.pointer(), pool_);

            multiply_poly_poly_polymod_coeffmod(encrypted1ptr.get(), encrypted2ptr.get(), polymod_, mod_, destination.pointer(), pool_);
            return;
        }

        // Multiply encrypted polynomials and perform key switching.
        Pointer product(allocate_poly(coeff_count, coeff_uint64_count, pool_));
        multiply(encrypted1.pointer(), encrypted2.pointer(), product.get());
        relinearize(product.get(), destination.pointer());
    }
    // Builds the ChooserPoly estimate for the product of all operands.
    //
    // The estimated coefficient count is 1 + sum(max_coeff_count_ - 1). The estimated max
    // absolute value is the product of all max_abs_value_ entries times an accumulated
    // growth factor. If an operand with zero max_abs_value_ is encountered, a
    // ChooserPoly(1, 0, ...) is returned immediately.
    //
    // Throws invalid_argument if operands is empty or an operand visited before any
    // early return is not correctly initialized.
    ChooserPoly ChooserEvaluator::multiply_many(const vector<ChooserPoly> &operands)
    {
        if (operands.empty())
        {
            throw invalid_argument("operands vector can not be empty");
        }

        int prod_max_coeff_count = 1;
        uint64_t growth_factor = 1;
        int prod_bit_count = 1;
        vector<Computation*> comps;

        bool is_first_operand = true;
        for (const ChooserPoly &operand : operands)
        {
            // Throw if any of the operands is not initialized correctly
            if (operand.max_coeff_count_ <= 0 || operand.comp_ == nullptr)
            {
                throw invalid_argument("input operand is not correctly initialized");
            }

            // Return early if the product is trivially zero
            // NOTE(review): comps only holds the operands seen so far at this point -
            // confirm MultiplyManyComputation is fine with a partial (possibly empty) list.
            if (operand.max_abs_value_.is_zero())
            {
                return ChooserPoly(1, 0, new MultiplyManyComputation(comps));
            }

            prod_max_coeff_count += operand.max_coeff_count_ - 1;
            prod_bit_count += operand.max_abs_value().significant_bit_count();

            // NOTE(review): prod_max_coeff_count already includes this operand here, so the
            // min() always selects operand.max_coeff_count_ - verify that is intended.
            growth_factor *= (is_first_operand ? 1 : min(operand.max_coeff_count_, prod_max_coeff_count));
            is_first_operand = false;

            comps.push_back(operand.comp_);
        }

        prod_bit_count += get_significant_bit_count(growth_factor);
        const int prod_uint64_count = divide_round_up(prod_bit_count, bits_per_uint64);

        // Seed the accumulator with the growth factor, then multiply in every operand's bound.
        Pointer prod_max_abs_value(allocate_zero_uint(prod_uint64_count, pool_));
        *prod_max_abs_value.get() = growth_factor;
        for (const ChooserPoly &operand : operands)
        {
            // The running product is read from a duplicate so that input and output
            // buffers are distinct.
            ConstPointer running_product(duplicate_uint_if_needed(prod_max_abs_value.get(), prod_uint64_count, prod_uint64_count, true, pool_));
            multiply_uint_uint(running_product.get(), prod_uint64_count, operand.max_abs_value_.pointer(), operand.max_abs_value_.uint64_count(), prod_uint64_count, prod_max_abs_value.get());
        }

        return ChooserPoly(prod_max_coeff_count, BigUInt(prod_bit_count, prod_max_abs_value.get()), new MultiplyManyComputation(comps));
    }
    // Builds the ChooserPoly estimate for the sum of all operands.
    //
    // The estimated coefficient count is the largest max_coeff_count_ among the operands.
    // The estimated max absolute value is the sum of all max_abs_value_ entries; its
    // buffer is sized from the largest entry plus the bit length of the operand count.
    //
    // Throws invalid_argument if operands is empty or any operand is not correctly
    // initialized.
    ChooserPoly ChooserEvaluator::add_many(const std::vector<ChooserPoly> &operands)
    {
        if (operands.empty())
        {
            throw invalid_argument("operands vector can not be empty");
        }

        int sum_max_coeff_count = operands[0].max_coeff_count_;
        vector<ChooserPoly>::size_type largest_abs_value_index = 0;
        for (vector<ChooserPoly>::size_type i = 0; i < operands.size(); ++i)
        {
            // Throw if any of the operands is not initialized correctly
            if (operands[i].max_coeff_count_ <= 0 || operands[i].comp_ == nullptr)
            {
                throw invalid_argument("input operand is not correctly initialized");
            }

            if (operands[i].max_coeff_count_ > sum_max_coeff_count)
            {
                sum_max_coeff_count = operands[i].max_coeff_count_;
            }

            // Track the operand with the strictly largest max_abs_value_. The "> 0" must apply
            // to the comparison result; it previously sat inside the argument list, which both
            // corrupted the length argument and treated "smaller" as "larger".
            if (compare_uint_uint(operands[i].max_abs_value_.pointer(), operands[i].max_abs_value_.uint64_count(), operands[largest_abs_value_index].max_abs_value_.pointer(), operands[largest_abs_value_index].max_abs_value_.uint64_count()) > 0)
            {
                largest_abs_value_index = i;
            }
        }

        // Summing n values, each no larger than the largest one, needs at most
        // bits(largest) + bits(n) bits.
        int sum_max_abs_value_bit_count = operands[largest_abs_value_index].max_abs_value_.significant_bit_count() + get_significant_bit_count(operands.size());
        int sum_max_abs_value_uint64_count = divide_round_up(sum_max_abs_value_bit_count, bits_per_uint64);
        Pointer sum_max_abs_value(allocate_zero_uint(sum_max_abs_value_uint64_count, pool_));

        vector<Computation*> comps;
        for (vector<ChooserPoly>::size_type i = 0; i < operands.size(); ++i)
        {
            add_uint_uint(operands[i].max_abs_value_.pointer(), operands[i].max_abs_value_.uint64_count(), sum_max_abs_value.get(), sum_max_abs_value_uint64_count, false, sum_max_abs_value_uint64_count, sum_max_abs_value.get());
            comps.push_back(operands[i].comp_);
        }

        return ChooserPoly(sum_max_coeff_count, BigUInt(sum_max_abs_value_bit_count, sum_max_abs_value.get()), new AddManyComputation(comps));
    }
    void Evaluator::relinearize(const BigPoly &encrypted, BigPoly &destination)
    {
        // Extract encryption parameters.
        int coeff_count = poly_modulus_.coeff_count();
        int coeff_bit_count = poly_modulus_.coeff_bit_count();
        int coeff_uint64_count = divide_round_up(coeff_bit_count, bits_per_uint64);

        // Verify parameters.
        if (encrypted.coeff_count() != coeff_count || encrypted.coeff_bit_count() != coeff_bit_count)
        {
            throw invalid_argument("encrypted is not valid for encryption parameters");
        }
#ifdef _DEBUG
        if (encrypted.significant_coeff_count() == coeff_count || !are_poly_coefficients_less_than(encrypted, coeff_modulus_))
        {
            throw invalid_argument("encrypted is not valid for encryption parameters");
        }
#endif
        if (destination.coeff_count() != coeff_count || destination.coeff_bit_count() != coeff_bit_count)
        {
            destination.resize(coeff_count, coeff_bit_count);
        }

        // Handle test-mode case.
        if (mode_ == TEST_MODE)
        {
            set_poly_poly(encrypted.pointer(), coeff_count, coeff_uint64_count, destination.pointer());
            return;
        }

        // Get pointer to inputs (duplicated if needed).
        ConstPointer encryptedptr = duplicate_poly_if_needed(encrypted, encrypted.pointer() == destination.pointer(), pool_);

        // Relinearize polynomial.
        relinearize(encryptedptr.get(), destination.pointer());
    }
示例#20
0
    void KeyGenerator::generate()
    {
        // Handle test-mode case.
        if (mode_ == TEST_MODE)
        {
            public_key_.set_zero();
            public_key_[0] = 1;
            secret_key_.set_zero();
            secret_key_[0] = 1;
            for (int i = 0; i < evaluation_keys_.count(); ++i)
            {
                evaluation_keys_[i].set_zero();
                evaluation_keys_[i][0] = 1;
            }
            return;
        }

        // Extract encryption parameters.
        int coeff_count = poly_modulus_.coeff_count();
        int coeff_bit_count = poly_modulus_.coeff_bit_count();
        int coeff_uint64_count = divide_round_up(coeff_bit_count, bits_per_uint64);

        // Loop until find a valid secret key.
        uint64_t *secret_key = secret_key_.pointer();
        set_zero_poly(coeff_count, coeff_uint64_count, secret_key);
        Pointer secret_key_inv(allocate_poly(coeff_count, coeff_uint64_count, pool_));
        while (true)
        {
            // Create noise with random [-1, 1] coefficients.
            set_poly_coeffs_zero_one_negone(secret_key);

            // Calculate secret_key * plaintext_modulus + 1.
            multiply_poly_scalar_coeffmod(secret_key, coeff_count, plain_modulus_.pointer(), mod_, secret_key, pool_);

            uint64_t *constant_coeff = get_poly_coeff(secret_key, 0, coeff_uint64_count);
            increment_uint_mod(constant_coeff, coeff_modulus_.pointer(), coeff_uint64_count, constant_coeff);

            // Attempt to invert secret_key.
            if (try_invert_poly_coeffmod(secret_key, poly_modulus_.pointer(), coeff_count, mod_, secret_key_inv.get(), pool_))
            {
                // Secret_key is invertible, so is valid
                break;
            }
        }

        // Calculate plaintext_modulus * noise * secret_key_inv.
        Pointer noise(allocate_poly(coeff_count, coeff_uint64_count, pool_));
        set_poly_coeffs_zero_one_negone(noise.get());
        uint64_t *public_key = public_key_.pointer();
        multiply_poly_poly_polymod_coeffmod(noise.get(), secret_key_inv.get(), polymod_, mod_, noise.get(), pool_);
        multiply_poly_scalar_coeffmod(noise.get(), coeff_count, plain_modulus_.pointer(), mod_, public_key, pool_);

        // Create evaluation keys.
        Pointer evaluation_factor(allocate_uint(coeff_uint64_count, pool_));
        set_uint(1, coeff_uint64_count, evaluation_factor.get());
        for (int i = 0; i < evaluation_keys_.count(); ++i)
        {
            // Multiply secret_key by evaluation_factor (mod coeff modulus).
            uint64_t *evaluation_key = evaluation_keys_[i].pointer();
            multiply_poly_scalar_coeffmod(secret_key, coeff_count, evaluation_factor.get(), mod_, evaluation_key, pool_);

            // Multiply public_key*normal noise and add into evaluation_key.
            set_poly_coeffs_normal(noise.get());
            multiply_poly_poly_polymod_coeffmod(noise.get(), public_key, polymod_, mod_, noise.get(), pool_);
            add_poly_poly_coeffmod(noise.get(), evaluation_key, coeff_count, coeff_modulus_.pointer(), coeff_uint64_count, evaluation_key);

            // Add-in more normal noise to evaluation_key.
            set_poly_coeffs_normal(noise.get());
            add_poly_poly_coeffmod(noise.get(), evaluation_key, coeff_count, coeff_modulus_.pointer(), coeff_uint64_count, evaluation_key);

            // Left shift evaluation factor.
            left_shift_uint(evaluation_factor.get(), decomposition_bit_count_, coeff_uint64_count, evaluation_factor.get());
        }
    }
示例#21
0
    // Builds a key generator from the given encryption parameters.
    //
    // Copies the moduli and noise settings out of `parms`, validates them (throwing
    // invalid_argument on the first violated requirement), normalizes all members to a common
    // coefficient count / bit width, and pre-sizes the public, secret and evaluation keys.
    // When `parms` carries no random generator factory, the default factory is used.
    KeyGenerator::KeyGenerator(const EncryptionParameters &parms) :
        poly_modulus_(parms.poly_modulus()), coeff_modulus_(parms.coeff_modulus()), plain_modulus_(parms.plain_modulus()),
        noise_standard_deviation_(parms.noise_standard_deviation()), noise_max_deviation_(parms.noise_max_deviation()),
        decomposition_bit_count_(parms.decomposition_bit_count()), mode_(parms.mode()),
        random_generator_(parms.random_generator() != nullptr ? parms.random_generator()->create() : UniformRandomGeneratorFactory::default_factory()->create())
    {
        // Throws invalid_argument carrying `message` unless `requirement` holds.
        const auto require = [](bool requirement, const char *message)
        {
            if (!requirement)
            {
                throw invalid_argument(message);
            }
        };

        // Required parameters must be non-zero / in range.
        require(!poly_modulus_.is_zero(), "poly_modulus cannot be zero");
        require(!coeff_modulus_.is_zero(), "coeff_modulus cannot be zero");
        require(!plain_modulus_.is_zero(), "plain_modulus cannot be zero");
        require(!(noise_standard_deviation_ < 0), "noise_standard_deviation must be non-negative");
        require(!(noise_max_deviation_ < 0), "noise_max_deviation must be non-negative");
        require(!(decomposition_bit_count_ <= 0), "decomposition_bit_count must be positive");

        // Relationships between the parameters.
        require(!(plain_modulus_ >= coeff_modulus_), "plain_modulus must be smaller than coeff_modulus");
        require(are_poly_coefficients_less_than(poly_modulus_, coeff_modulus_), "poly_modulus cannot have coefficients larger than coeff_modulus");

        // Common dimensions: as many coefficients as poly_modulus has significant ones, each as
        // wide as the significant part of coeff_modulus.
        const int key_coeff_count = poly_modulus_.significant_coeff_count();
        const int key_coeff_bits = coeff_modulus_.significant_bit_count();
        const int key_coeff_words = divide_round_up(key_coeff_bits, bits_per_uint64);

        // Resize each member only when it is not already at the common size.
        const bool poly_modulus_fits =
            poly_modulus_.coeff_count() == key_coeff_count && poly_modulus_.coeff_bit_count() == key_coeff_bits;
        if (!poly_modulus_fits)
        {
            poly_modulus_.resize(key_coeff_count, key_coeff_bits);
        }
        if (coeff_modulus_.bit_count() != key_coeff_bits)
        {
            coeff_modulus_.resize(key_coeff_bits);
        }
        if (plain_modulus_.bit_count() != key_coeff_bits)
        {
            plain_modulus_.resize(key_coeff_bits);
        }

        // The decomposition width is capped at the coefficient width.
        if (decomposition_bit_count_ > key_coeff_bits)
        {
            decomposition_bit_count_ = key_coeff_bits;
        }

        // coeff_modulus - 1, i.e. -1 modulo coeff_modulus.
        coeff_modulus_minus_one_.resize(key_coeff_bits);
        decrement_uint(coeff_modulus_.pointer(), key_coeff_words, coeff_modulus_minus_one_.pointer());

        // Pre-size the public and secret keys.
        public_key_.resize(key_coeff_count, key_coeff_bits);
        secret_key_.resize(key_coeff_count, key_coeff_bits);

        // Count the evaluation keys: starting from 1, shift the factor left by
        // decomposition_bit_count_ until it becomes zero or is no longer below coeff_modulus.
        int needed_key_count = 0;
        Pointer factor(allocate_uint(key_coeff_words, pool_));
        set_uint(1, key_coeff_words, factor.get());
        for (;;)
        {
            const bool factor_exhausted =
                is_zero_uint(factor.get(), key_coeff_words) || !is_less_than_uint_uint(factor.get(), coeff_modulus_.pointer(), key_coeff_words);
            if (factor_exhausted)
            {
                break;
            }
            left_shift_uint(factor.get(), decomposition_bit_count_, key_coeff_words, factor.get());
            ++needed_key_count;
        }

        // Allocate that many evaluation keys at the common size.
        evaluation_keys_.resize(needed_key_count);
        for (int key_index = 0; key_index < needed_key_count; ++key_index)
        {
            evaluation_keys_[key_index].resize(key_coeff_count, key_coeff_bits);
        }

        // Modulus helper objects wrapping the (now resized) member buffers.
        polymod_ = PolyModulus(poly_modulus_.pointer(), key_coeff_count, key_coeff_words);
        mod_ = Modulus(coeff_modulus_.pointer(), key_coeff_words, pool_);
    }
    // Multiplies two ciphertext polynomials given as raw coefficient arrays and writes the
    // result to destination.
    //
    // Every coefficient of the polynomial product is scaled as
    // (coeff * plain_modulus + coeff_modulus / 2) / coeff_modulus and then reduced with mod_.
    // Two code paths exist:
    //   - FFT path, taken when polymod_ reports a power-of-two coefficient count and
    //     is_one_zero_one(): the product has coeff_count coefficients and is treated as signed
    //     (high bit set means negative).
    //   - Schoolbook path: the full 2 * coeff_count - 1 coefficient product is scaled first and
    //     reduced modulo the polynomial modulus afterwards.
    //
    // encrypted1, encrypted2: inputs, read as coeff_count coefficients of coeff_uint64_count
    //     words each (dimensions taken from poly_modulus_).
    // destination: output with the same layout. It is zeroed before the inputs are read, so it
    //     presumably must not alias either input -- TODO confirm against callers.
    void Evaluator::multiply(const uint64_t *encrypted1, const uint64_t *encrypted2, uint64_t *destination)
    {
        // Extract encryption parameters.
        int coeff_count = poly_modulus_.coeff_count();
        int coeff_bit_count = poly_modulus_.coeff_bit_count();
        int coeff_uint64_count = divide_round_up(coeff_bit_count, bits_per_uint64);

        // Clear destination.
        set_zero_poly(coeff_count, coeff_uint64_count, destination);

        // Determine if FFT can be used.
        bool use_fft = polymod_.coeff_count_power_of_two() >= 0 && polymod_.is_one_zero_one();

        if (use_fft)
        {
            // Use FFT to multiply polynomials.

            // Allocate polynomial to store product of two polynomials, with poly but no coeff modulo yet (and signed).
            // Width: two coefficient widths, plus the bit length of coeff_count, plus 2 extra bits.
            int product_coeff_bit_count = coeff_bit_count + coeff_bit_count + get_significant_bit_count(static_cast<uint64_t>(coeff_count)) + 2;
            int product_coeff_uint64_count = divide_round_up(product_coeff_bit_count, bits_per_uint64);
            Pointer product(allocate_poly(coeff_count, product_coeff_uint64_count, pool_));

            // Use FFT to multiply polynomials.
            // The last coefficient is zeroed up front; presumably the FFT routine only writes the
            // first coeff_count - 1 coefficients -- TODO confirm.
            set_zero_uint(product_coeff_uint64_count, get_poly_coeff(product.get(), coeff_count - 1, product_coeff_uint64_count));
            fftmultiply_poly_poly_polymod(encrypted1, encrypted2, polymod_.coeff_count_power_of_two(), coeff_uint64_count, product_coeff_uint64_count, product.get(), pool_);

            // For each coefficient in product, multiply by plain_modulus and divide by coeff_modulus and then modulo by coeff_modulus.
            int plain_modulus_bit_count = plain_modulus_.significant_bit_count();
            int plain_modulus_uint64_count = divide_round_up(plain_modulus_bit_count, bits_per_uint64);
            int intermediate_bit_count = product_coeff_bit_count + plain_modulus_bit_count - 1;
            int intermediate_uint64_count = divide_round_up(intermediate_bit_count, bits_per_uint64);
            // NOTE(review): wide_coeff_modulus_ and wide_coeff_modulus_div_two_ are read below as
            // intermediate_uint64_count words. This width is one bit larger than the one used in
            // the schoolbook branch; confirm both buffers are sized for at least this many words.
            Pointer intermediate(allocate_uint(intermediate_uint64_count, pool_));
            Pointer quotient(allocate_uint(intermediate_uint64_count, pool_));
            for (int coeff_index = 0; coeff_index < coeff_count; ++coeff_index)
            {
                // Work on the absolute value and remember the sign; the product buffer is
                // modified in place.
                uint64_t *product_coeff = get_poly_coeff(product.get(), coeff_index, product_coeff_uint64_count);
                bool coeff_is_negative = is_high_bit_set_uint(product_coeff, product_coeff_uint64_count);
                if (coeff_is_negative)
                {
                    negate_uint(product_coeff, product_coeff_uint64_count, product_coeff);
                }
                // intermediate = |coeff| * plain_modulus + coeff_modulus / 2 (the added half makes
                // the following division round to nearest instead of truncating).
                multiply_uint_uint(product_coeff, product_coeff_uint64_count, plain_modulus_.pointer(), plain_modulus_uint64_count, intermediate_uint64_count, intermediate.get());
                add_uint_uint(intermediate.get(), wide_coeff_modulus_div_two_.pointer(), intermediate_uint64_count, intermediate.get());
                // quotient = intermediate / coeff_modulus, then reduced with mod_.
                divide_uint_uint_inplace(intermediate.get(), wide_coeff_modulus_.pointer(), intermediate_uint64_count, quotient.get(), pool_);
                modulo_uint_inplace(quotient.get(), intermediate_uint64_count, mod_, pool_);
                uint64_t *dest_coeff = get_poly_coeff(destination, coeff_index, coeff_uint64_count);
                if (coeff_is_negative)
                {
                    // Restore the sign modulo coeff_modulus.
                    negate_uint_mod(quotient.get(), coeff_modulus_.pointer(), coeff_uint64_count, dest_coeff);
                }
                else
                {
                    // Only the low coeff_uint64_count words of the quotient are copied.
                    set_uint_uint(quotient.get(), coeff_uint64_count, dest_coeff);
                }
            }
        }
        else
        {
            // Use normal multiplication to multiply polynomials.

            // Allocate polynomial to store product of two polynomials, with no poly or coeff modulo yet.
            int product_coeff_count = coeff_count + coeff_count - 1;
            int product_coeff_bit_count = coeff_bit_count + coeff_bit_count + get_significant_bit_count(static_cast<uint64_t>(coeff_count));
            int product_coeff_uint64_count = divide_round_up(product_coeff_bit_count, bits_per_uint64);
            Pointer product(allocate_poly(product_coeff_count, product_coeff_uint64_count, pool_));

            // Multiply polynomials.
            multiply_poly_poly(encrypted1, coeff_count, coeff_uint64_count, encrypted2, coeff_count, coeff_uint64_count, product_coeff_count, product_coeff_uint64_count, product.get(), pool_);

            // For each coefficient in product, multiply by plain_modulus and divide by coeff_modulus and then modulo by coeff_modulus.
            int plain_modulus_bit_count = plain_modulus_.significant_bit_count();
            int plain_modulus_uint64_count = divide_round_up(plain_modulus_bit_count, bits_per_uint64);
            int intermediate_bit_count = product_coeff_bit_count + plain_modulus_bit_count;
            int intermediate_uint64_count = divide_round_up(intermediate_bit_count, bits_per_uint64);
            Pointer intermediate(allocate_uint(intermediate_uint64_count, pool_));
            Pointer quotient(allocate_uint(intermediate_uint64_count, pool_));
            // Scaled coefficients, already narrowed to coeff_uint64_count words each.
            Pointer productmoded(allocate_poly(product_coeff_count, coeff_uint64_count, pool_));
            for (int coeff_index = 0; coeff_index < product_coeff_count; ++coeff_index)
            {
                // Same scale-and-round as the FFT branch, but coefficients here are unsigned.
                const uint64_t *product_coeff = get_poly_coeff(product.get(), coeff_index, product_coeff_uint64_count);
                multiply_uint_uint(product_coeff, product_coeff_uint64_count, plain_modulus_.pointer(), plain_modulus_uint64_count, intermediate_uint64_count, intermediate.get());
                add_uint_uint(intermediate.get(), wide_coeff_modulus_div_two_.pointer(), intermediate_uint64_count, intermediate.get());
                divide_uint_uint_inplace(intermediate.get(), wide_coeff_modulus_.pointer(), intermediate_uint64_count, quotient.get(), pool_);
                modulo_uint_inplace(quotient.get(), intermediate_uint64_count, mod_, pool_);
                uint64_t *productmoded_coeff = get_poly_coeff(productmoded.get(), coeff_index, coeff_uint64_count);
                set_uint_uint(quotient.get(), coeff_uint64_count, productmoded_coeff);
            }

            // Perform polynomial modulo.
            modulo_poly_inplace(productmoded.get(), product_coeff_count, polymod_, mod_, pool_);

            // Copy to destination.
            // Only the first coeff_count coefficients of the reduced polynomial are copied.
            set_poly_poly(productmoded.get(), coeff_count, coeff_uint64_count, destination);
        }
    }
    // Builds an evaluator from encryption parameters and a matching set of evaluation keys.
    //
    // parms:           source of the polynomial modulus, coefficient modulus, plain modulus,
    //                  decomposition bit count and mode.
    // evaluation_keys: copied into the evaluator; their number and dimensions must match what
    //                  the parameters imply.
    //
    // Throws invalid_argument when a required parameter is zero / non-positive, when
    // plain_modulus is not smaller than coeff_modulus, when poly_modulus has a coefficient that
    // is not smaller than coeff_modulus, or when the evaluation keys do not fit the parameters.
    //
    // Besides validation, the constructor normalizes all members to a common coefficient
    // count / bit width and precomputes the constants used by the arithmetic routines
    // (coeff_modulus / plain_modulus, the upper-half thresholds and increments, and a widened
    // copy of coeff_modulus and of its half).
    Evaluator::Evaluator(const EncryptionParameters &parms, const EvaluationKeys &evaluation_keys) :
        poly_modulus_(parms.poly_modulus()), coeff_modulus_(parms.coeff_modulus()), plain_modulus_(parms.plain_modulus()),
        decomposition_bit_count_(parms.decomposition_bit_count()), evaluation_keys_(evaluation_keys), mode_(parms.mode())
    {
        // Verify required parameters are non-zero and non-nullptr.
        if (poly_modulus_.is_zero())
        {
            throw invalid_argument("poly_modulus cannot be zero");
        }
        if (coeff_modulus_.is_zero())
        {
            throw invalid_argument("coeff_modulus cannot be zero");
        }
        if (plain_modulus_.is_zero())
        {
            throw invalid_argument("plain_modulus cannot be zero");
        }
        if (decomposition_bit_count_ <= 0)
        {
            throw invalid_argument("decomposition_bit_count must be positive");
        }
        if (evaluation_keys_.count() == 0)
        {
            throw invalid_argument("evaluation_keys cannot be empty");
        }

        // Verify parameters.
        if (plain_modulus_ >= coeff_modulus_)
        {
            throw invalid_argument("plain_modulus must be smaller than coeff_modulus");
        }
        if (!are_poly_coefficients_less_than(poly_modulus_, coeff_modulus_))
        {
            throw invalid_argument("poly_modulus cannot have coefficients larger than coeff_modulus");
        }

        // Resize encryption parameters to consistent size: as many coefficients as poly_modulus
        // has significant ones, each as wide as the significant part of coeff_modulus.
        int coeff_count = poly_modulus_.significant_coeff_count();
        int coeff_bit_count = coeff_modulus_.significant_bit_count();
        int coeff_uint64_count = divide_round_up(coeff_bit_count, bits_per_uint64);
        if (poly_modulus_.coeff_count() != coeff_count || poly_modulus_.coeff_bit_count() != coeff_bit_count)
        {
            poly_modulus_.resize(coeff_count, coeff_bit_count);
        }
        if (coeff_modulus_.bit_count() != coeff_bit_count)
        {
            coeff_modulus_.resize(coeff_bit_count);
        }
        if (plain_modulus_.bit_count() != coeff_bit_count)
        {
            plain_modulus_.resize(coeff_bit_count);
        }
        // The decomposition width is capped at the coefficient width.
        if (decomposition_bit_count_ > coeff_bit_count)
        {
            decomposition_bit_count_ = coeff_bit_count;
        }

        // Determine correct number of evaluation keys: starting from 1, shift the factor left
        // by decomposition_bit_count_ until it becomes zero or is no longer below coeff_modulus.
        int evaluation_key_count = 0;
        Pointer evaluation_factor(allocate_uint(coeff_uint64_count, pool_));
        set_uint(1, coeff_uint64_count, evaluation_factor.get());
        while (!is_zero_uint(evaluation_factor.get(), coeff_uint64_count) && is_less_than_uint_uint(evaluation_factor.get(), coeff_modulus_.pointer(), coeff_uint64_count))
        {
            left_shift_uint(evaluation_factor.get(), decomposition_bit_count_, coeff_uint64_count, evaluation_factor.get());
            evaluation_key_count++;
        }

        // Verify evaluation keys: right count, right dimensions, significant_coeff_count()
        // below the full coefficient count, and all coefficients below coeff_modulus.
        if (evaluation_keys_.count() != evaluation_key_count)
        {
            throw invalid_argument("evaluation_keys is not valid for encryption parameters");
        }
        for (int i = 0; i < evaluation_keys_.count(); ++i)
        {
            BigPoly &evaluation_key = evaluation_keys_[i];
            if (evaluation_key.coeff_count() != coeff_count || evaluation_key.coeff_bit_count() != coeff_bit_count ||
                evaluation_key.significant_coeff_count() == coeff_count || !are_poly_coefficients_less_than(evaluation_key, coeff_modulus_))
            {
                throw invalid_argument("evaluation_keys is not valid for encryption parameters");
            }
        }

        // Calculate coeff_modulus / plain_modulus (the remainder goes to a scratch buffer and is
        // discarded).
        coeff_div_plain_modulus_.resize(coeff_bit_count);
        Pointer temp(allocate_uint(coeff_uint64_count, pool_));
        divide_uint_uint(coeff_modulus_.pointer(), plain_modulus_.pointer(), coeff_uint64_count, coeff_div_plain_modulus_.pointer(), temp.get(), pool_);

        // Calculate (plain_modulus + 1) / 2.
        plain_upper_half_threshold_.resize(coeff_bit_count);
        half_round_up_uint(plain_modulus_.pointer(), coeff_uint64_count, plain_upper_half_threshold_.pointer());

        // Calculate coeff_modulus - plain_modulus.
        plain_upper_half_increment_.resize(coeff_bit_count);
        sub_uint_uint(coeff_modulus_.pointer(), plain_modulus_.pointer(), coeff_uint64_count, plain_upper_half_increment_.pointer());

        // Calculate (plain_modulus + 1) / 2 * coeff_div_plain_modulus.
        upper_half_threshold_.resize(coeff_bit_count);
        multiply_truncate_uint_uint(plain_upper_half_threshold_.pointer(), coeff_div_plain_modulus_.pointer(), coeff_uint64_count, upper_half_threshold_.pointer());

        // Calculate upper_half_increment = coeff_modulus - plain_modulus * coeff_div_plain_modulus.
        upper_half_increment_.resize(coeff_bit_count);
        multiply_truncate_uint_uint(plain_modulus_.pointer(), coeff_div_plain_modulus_.pointer(), coeff_uint64_count, upper_half_increment_.pointer());
        sub_uint_uint(coeff_modulus_.pointer(), upper_half_increment_.pointer(), coeff_uint64_count, upper_half_increment_.pointer());

        // Widen coeff modulus. The widened buffers must cover the widest intermediate value that
        // multiply() forms, because multiply() reads them with that word count:
        //   - schoolbook path: product_coeff_bit_count + plain_modulus_bit_count bits;
        //   - FFT path: (product_coeff_bit_count + 2) + plain_modulus_bit_count - 1 bits,
        //     i.e. one bit more.
        // Sizing for the narrower width only would leave the buffers one word short of what the
        // FFT path reads whenever that width is a multiple of 64, so size for the FFT width.
        int product_coeff_bit_count = coeff_bit_count + coeff_bit_count + get_significant_bit_count(static_cast<uint64_t>(coeff_count));
        int plain_modulus_bit_count = plain_modulus_.significant_bit_count();
        int wide_bit_count = product_coeff_bit_count + plain_modulus_bit_count + 1;
        int wide_uint64_count = divide_round_up(wide_bit_count, bits_per_uint64);
        wide_coeff_modulus_.resize(wide_bit_count);
        // Presumably the assignment keeps the widened bit count and zero-extends the value.
        wide_coeff_modulus_ = coeff_modulus_;

        // Calculate wide_coeff_modulus_ / 2.
        wide_coeff_modulus_div_two_.resize(wide_bit_count);
        right_shift_uint(wide_coeff_modulus_.pointer(), 1, wide_uint64_count, wide_coeff_modulus_div_two_.pointer());

        // Initialize moduli. In test mode coefficient arithmetic is carried out modulo
        // plain_modulus instead of coeff_modulus.
        polymod_ = PolyModulus(poly_modulus_.pointer(), coeff_count, coeff_uint64_count);
        if (mode_ == TEST_MODE)
        {
            mod_ = Modulus(plain_modulus_.pointer(), coeff_uint64_count, pool_);
        }
        else
        {
            mod_ = Modulus(coeff_modulus_.pointer(), coeff_uint64_count, pool_);
        }
    }
示例#24
0
    void KeyGenerator::generate(const BigPoly &secret_key, uint64_t power)
    {
        // Validate arguments.
        if (secret_key.is_zero())
        {
            throw invalid_argument("secret_key cannot be zero");
        }
        if (power == 0)
        {
            throw invalid_argument("power cannot be zero");
        }

        // Handle test-mode case.
        if (mode_ == TEST_MODE)
        {
            public_key_.set_zero();
            public_key_[0] = 1;
            secret_key_.set_zero();
            secret_key_[0] = 1;
            for (int i = 0; i < evaluation_keys_.count(); ++i)
            {
                evaluation_keys_[i].set_zero();
                evaluation_keys_[i][0] = 1;
            }
            return;
        }

        // Extract encryption parameters.
        int coeff_count = poly_modulus_.coeff_count();
        int coeff_bit_count = poly_modulus_.coeff_bit_count();
        int coeff_uint64_count = divide_round_up(coeff_bit_count, bits_per_uint64);

        // Verify secret key looks valid.
        secret_key_ = secret_key;
        if (secret_key_.coeff_count() != coeff_count || secret_key_.coeff_bit_count() != coeff_bit_count)
        {
            throw invalid_argument("secret_key is not valid for encryption parameters");
        }
#ifdef _DEBUG
        if (secret_key_.significant_coeff_count() == coeff_count || !are_poly_coefficients_less_than(secret_key_, coeff_modulus_))
        {
            throw invalid_argument("secret_key is not valid for encryption parameters");
        }
#endif

        // Raise level of secret key.
        if (power > 1)
        {
            exponentiate_poly_polymod_coeffmod(secret_key_.pointer(), &power, 1, polymod_, mod_, secret_key_.pointer(), pool_);
        }

        // Attempt to invert secret_key.
        Pointer secret_key_inv(allocate_poly(coeff_count, coeff_uint64_count, pool_));
        if (!try_invert_poly_coeffmod(secret_key_.pointer(), poly_modulus_.pointer(), coeff_count, mod_, secret_key_inv.get(), pool_))
        {
            // Secret_key is not invertible, so not valid.
            throw invalid_argument("secret_key is not valid for encryption parameters");
        }

        // Calculate plaintext_modulus * noise * secret_key_inv.
        Pointer noise(allocate_poly(coeff_count, coeff_uint64_count, pool_));
        set_poly_coeffs_zero_one_negone(noise.get());
        uint64_t *public_key = public_key_.pointer();
        multiply_poly_poly_polymod_coeffmod(noise.get(), secret_key_inv.get(), polymod_, mod_, noise.get(), pool_);
        multiply_poly_scalar_coeffmod(noise.get(), coeff_count, plain_modulus_.pointer(), mod_, public_key, pool_);

        // Create evaluation keys.
        Pointer evaluation_factor(allocate_uint(coeff_uint64_count, pool_));
        set_uint(1, coeff_uint64_count, evaluation_factor.get());
        for (int i = 0; i < evaluation_keys_.count(); ++i)
        {
            // Multiply secret_key by evaluation_factor (mod coeff modulus).
            uint64_t *evaluation_key = evaluation_keys_[i].pointer();
            multiply_poly_scalar_coeffmod(secret_key_.pointer(), coeff_count, evaluation_factor.get(), mod_, evaluation_key, pool_);

            // Multiply public_key*normal noise and add into evaluation_key.
            set_poly_coeffs_normal(noise.get());
            multiply_poly_poly_polymod_coeffmod(noise.get(), public_key, polymod_, mod_, noise.get(), pool_);
            add_poly_poly_coeffmod(noise.get(), evaluation_key, coeff_count, coeff_modulus_.pointer(), coeff_uint64_count, evaluation_key);

            // Add-in more normal noise to evaluation_key.
            set_poly_coeffs_normal(noise.get());
            add_poly_poly_coeffmod(noise.get(), evaluation_key, coeff_count, coeff_modulus_.pointer(), coeff_uint64_count, evaluation_key);

            // Left shift evaluation factor.
            left_shift_uint(evaluation_factor.get(), decomposition_bit_count_, coeff_uint64_count, evaluation_factor.get());
        }
    }
示例#25
0
/// Try to print the list of completions l with the prefix prefix using cols as the number of
/// columns. Return true if the completion list was printed, false if the terminal is to narrow for
/// the specified number of columns. Always succeeds if cols is 1.
bool pager_t::completion_try_print(size_t cols, const wcstring &prefix, const comp_info_list_t &lst,
                                   page_rendering_t *rendering, size_t suggested_start_row) const {
    // The calculated preferred width of each column.
    int pref_width[PAGER_MAX_COLS] = {0};
    // The calculated minimum width of each column.
    int min_width[PAGER_MAX_COLS] = {0};
    // If the list can be printed with this width, width will contain the width of each column.
    int *width = pref_width;

    // Set to one if the list should be printed at this width.
    bool print = false;

    // Compute the effective term width and term height, accounting for disclosure.
    size_t term_width = this->available_term_width;
    size_t term_height =
        this->available_term_height - 1 -
        (search_field_shown ? 1 : 0);  // we always subtract 1 to make room for a comment row
    if (!this->fully_disclosed) {
        term_height = mini(term_height, (size_t)PAGER_UNDISCLOSED_MAX_ROWS);
    }

    size_t row_count = divide_round_up(lst.size(), cols);

    // We have more to disclose if we are not fully disclosed and there's more rows than we have in
    // our term height.
    if (!this->fully_disclosed && row_count > term_height) {
        rendering->remaining_to_disclose = row_count - term_height;
    } else {
        rendering->remaining_to_disclose = 0;
    }

    // If we have only one row remaining to disclose, then squelch the comment row. This prevents us
    // from consuming a line to show "...and 1 more row".
    if (!this->fully_disclosed && rendering->remaining_to_disclose == 1) {
        term_height += 1;
        rendering->remaining_to_disclose = 0;
    }

    size_t pref_tot_width = 0;
    size_t min_tot_width = 0;

    // Skip completions on tiny terminals.
    if (term_width < PAGER_MIN_WIDTH) return true;

    // Calculate how wide the list would be.
    for (size_t col = 0; col < cols; col++) {
        for (size_t row = 0; row < row_count; row++) {
            int pref, min;
            const comp_t *c;
            if (lst.size() <= col * row_count + row) continue;

            c = &lst.at(col * row_count + row);
            pref = c->pref_width;
            min = c->min_width;

            if (col != cols - 1) {
                pref += 2;
                min += 2;
            }
            min_width[col] = maxi(min_width[col], min);
            pref_width[col] = maxi(pref_width[col], pref);
        }
        min_tot_width += min_width[col];
        pref_tot_width += pref_width[col];
    }

    // Force fit if one column.
    if (cols == 1) {
        if (pref_tot_width > term_width) {
            pref_width[0] = term_width;
        }
        width = pref_width;
        print = true;
    } else if (pref_tot_width <= term_width) {
        // Terminal is wide enough. Print the list!
        width = pref_width;
        print = true;
    }

    if (print) {
        // Determine the starting and stop row.
        size_t start_row = 0, stop_row = 0;
        if (row_count <= term_height) {
            // Easy, we can show everything.
            start_row = 0;
            stop_row = row_count;
        } else {
            // We can only show part of the full list. Determine which part based on the
            // suggested_start_row.
            assert(row_count > term_height);
            size_t last_starting_row = row_count - term_height;
            start_row = mini(suggested_start_row, last_starting_row);
            stop_row = start_row + term_height;
            assert(start_row >= 0 && start_row <= last_starting_row);
        }

        assert(stop_row >= start_row);
        assert(stop_row <= row_count);
        assert(stop_row - start_row <= term_height);
        completion_print(cols, width, start_row, stop_row, prefix, lst, rendering);

        // Ellipsis helper string. Either empty or containing the ellipsis char.
        const wchar_t ellipsis_string[] = {ellipsis_char == L'\x2026' ? L'\x2026' : L'\0', L'\0'};

        // Add the progress line. It's a "more to disclose" line if necessary, or a row listing if
        // it's scrollable; otherwise ignore it.
        wcstring progress_text;
        if (rendering->remaining_to_disclose == 1) {
            // I don't expect this case to ever happen.
            progress_text = format_string(_(L"%lsand 1 more row"), ellipsis_string);
        } else if (rendering->remaining_to_disclose > 1) {
            progress_text = format_string(_(L"%lsand %lu more rows"), ellipsis_string,
                                          (unsigned long)rendering->remaining_to_disclose);
        } else if (start_row > 0 || stop_row < row_count) {
            // We have a scrollable interface. The +1 here is because we are zero indexed, but want
            // to present things as 1-indexed. We do not add 1 to stop_row or row_count because
            // these are the "past the last value".
            progress_text =
                format_string(_(L"rows %lu to %lu of %lu"), start_row + 1, stop_row, row_count);
        } else if (completion_infos.empty() && !unfiltered_completion_infos.empty()) {
            // Everything is filtered.
            progress_text = _(L"(no matches)");
        }

        if (!progress_text.empty()) {
            line_t &line = rendering->screen_data.add_line();
            print_max(progress_text, highlight_spec_pager_progress |
                                         highlight_make_background(highlight_spec_pager_progress),
                      term_width, true /* has_more */, &line);
        }

        if (search_field_shown) {
            // Add the search field.
            wcstring search_field_text = search_field_line.text;
            // Append spaces to make it at least the required width.
            if (search_field_text.size() < PAGER_SEARCH_FIELD_WIDTH) {
                search_field_text.append(PAGER_SEARCH_FIELD_WIDTH - search_field_text.size(), L' ');
            }
            line_t *search_field = &rendering->screen_data.insert_line_at_index(0);

            // We limit the width to term_width - 1.
            int search_field_written = print_max(SEARCH_FIELD_PROMPT, highlight_spec_normal,
                                                 term_width - 1, false, search_field);
            print_max(search_field_text, highlight_modifier_force_underline,
                      term_width - search_field_written - 1, false, search_field);
        }
    }
    return print;
}
    /**
    Builds an HkeyGen from a set of encryption parameters and a secret key.

    Local copies of the polynomial modulus, coefficient modulus and plain modulus are
    taken from parms, validated, and resized to a common coefficient width. The secret
    key is only validated against that width (it is never resized here) and is then
    stored as the first entry of secret_key_array_. Several values derived from the
    moduli are precomputed at the end.

    @param[in] parms The encryption parameters supplying the three moduli
    @param[in] secret_key The secret key polynomial
    @throws std::invalid_argument if any modulus or the secret key is zero, if
    plain_modulus is not smaller than coeff_modulus, if poly_modulus has a coefficient
    that is not less than coeff_modulus, or if secret_key does not match the parameters
    */
    HkeyGen::HkeyGen(const EncryptionParameters &parms, const BigPoly &secret_key) :
        poly_modulus_(parms.poly_modulus()),
        coeff_modulus_(parms.coeff_modulus()),
        plain_modulus_(parms.plain_modulus()),
        secret_key_(secret_key),
        orig_plain_modulus_bit_count_(parms.plain_modulus().significant_bit_count())
    {
        // None of the moduli, nor the secret key, may be zero.
        if (poly_modulus_.is_zero())
            throw invalid_argument("poly_modulus cannot be zero");
        if (coeff_modulus_.is_zero())
            throw invalid_argument("coeff_modulus cannot be zero");
        if (plain_modulus_.is_zero())
            throw invalid_argument("plain_modulus cannot be zero");
        if (secret_key_.is_zero())
            throw invalid_argument("secret_key cannot be zero");

        // Relationships that must hold between the moduli.
        if (plain_modulus_ >= coeff_modulus_)
            throw invalid_argument("plain_modulus must be smaller than coeff_modulus");
        if (!are_poly_coefficients_less_than(poly_modulus_, coeff_modulus_))
            throw invalid_argument("poly_modulus cannot have coefficients larger than coeff_modulus");

        // Common dimensions: coefficient count from the polynomial modulus, coefficient
        // width (in bits and in 64-bit words) from the coefficient modulus.
        const int poly_coeffs = poly_modulus_.significant_coeff_count();
        const int coeff_bits = coeff_modulus_.significant_bit_count();
        const int coeff_words = divide_round_up(coeff_bits, bits_per_uint64);

        // Bring every modulus to the common size, touching only those that differ.
        const bool poly_modulus_sized =
            poly_modulus_.coeff_count() == poly_coeffs && poly_modulus_.coeff_bit_count() == coeff_bits;
        if (!poly_modulus_sized)
            poly_modulus_.resize(poly_coeffs, coeff_bits);
        if (coeff_modulus_.bit_count() != coeff_bits)
            coeff_modulus_.resize(coeff_bits);
        if (plain_modulus_.bit_count() != coeff_bits)
            plain_modulus_.resize(coeff_bits);

        // The secret key must already have the common size, must have fewer significant
        // coefficients than the polynomial modulus, and must be reduced modulo coeff_modulus.
        const bool secret_key_fits =
            secret_key_.coeff_count() == poly_coeffs &&
            secret_key_.coeff_bit_count() == coeff_bits &&
            secret_key_.significant_coeff_count() != poly_coeffs &&
            are_poly_coefficients_less_than(secret_key_, coeff_modulus_);
        if (!secret_key_fits)
            throw invalid_argument("secret_key is not valid for encryption parameters");

        // secret_key_array_ starts out holding a single polynomial: the secret key itself.
        secret_key_array_.resize(1, poly_coeffs, coeff_bits);
        set_poly_poly(secret_key_.pointer(), poly_coeffs, coeff_words, secret_key_array_.pointer(0));

        MemoryPool &default_pool = *MemoryPool::default_pool();

        // coeff_div_plain_modulus_ = coeff_modulus / plain_modulus.
        // The scratch buffer is the extra output of the division and is not used afterwards.
        coeff_div_plain_modulus_.resize(coeff_bits);
        Pointer scratch(allocate_uint(coeff_words, default_pool));
        divide_uint_uint(coeff_modulus_.pointer(), plain_modulus_.pointer(), coeff_words,
            coeff_div_plain_modulus_.pointer(), scratch.get(), default_pool);

        // coeff_div_plain_modulus_div_two_ = (coeff_modulus / plain_modulus) shifted right by one bit.
        coeff_div_plain_modulus_div_two_.resize(coeff_bits);
        right_shift_uint(coeff_div_plain_modulus_.pointer(), 1, coeff_words,
            coeff_div_plain_modulus_div_two_.pointer());

        // upper_half_threshold_ = coeff_modulus / 2, via half_round_up_uint.
        upper_half_threshold_.resize(coeff_bits);
        half_round_up_uint(coeff_modulus_.pointer(), coeff_words, upper_half_threshold_.pointer());

        // upper_half_increment_ = coeff_modulus - plain_modulus * (coeff_modulus / plain_modulus).
        upper_half_increment_.resize(coeff_bits);
        multiply_truncate_uint_uint(plain_modulus_.pointer(), coeff_div_plain_modulus_.pointer(), coeff_words,
            upper_half_increment_.pointer());
        sub_uint_uint(coeff_modulus_.pointer(), upper_half_increment_.pointer(), coeff_words,
            upper_half_increment_.pointer());

        // Wrap the resized moduli in the helper objects used for modular arithmetic.
        polymod_ = PolyModulus(poly_modulus_.pointer(), poly_coeffs, coeff_words);
        mod_ = Modulus(coeff_modulus_.pointer(), coeff_words, default_pool);
    }
示例#27
0
// Rounds n up to a multiple of d: the result is divide_round_up(n, d) multiplied by d,
// converted to Integer1 on return.
// Both Integer1 and Integer2 must be integral types; this is enforced at compile time.
// NOTE(review): the template parameter list declaring Integer1/Integer2 is not visible
// in this excerpt - confirm it precedes this definition in the real source.
inline Integer1 round_up_to_multiple(Integer1 const n, Integer2 const d) {
  // Compile-time checks that both argument types are integral.
  BOOST_MPL_ASSERT((boost::is_integral<Integer1>));
  BOOST_MPL_ASSERT((boost::is_integral<Integer2>));
  return divide_round_up(n,d)*d;
}
示例#28
0
/**
   Try to lay out the completions in lst using exactly cols columns and, if that layout
   is acceptable, render it into rendering.

   The layout is accepted when there is only one column (its width is clamped to the
   terminal width), when the preferred column widths fit within the terminal width, or
   when the columns can be squeezed below their preferred widths under the conditions
   described in the squeeze branch below.

   \param cols the number of columns to try
   \param prefix passed through unchanged to completion_print
   \param lst the completions to lay out
   \param rendering receives remaining_to_disclose and, on success, the rendered lines
   \param suggested_start_row the first row to show when not all rows fit; it is clamped
          so that a full page of rows is displayed

   \return true if the list was rendered with this column count, or if the terminal is
           narrower than PAGER_MIN_WIDTH (in which case no lines are rendered); false if
           this column count does not fit.
*/
bool pager_t::completion_try_print(size_t cols, const wcstring &prefix, const comp_info_list_t &lst, page_rendering_t *rendering, size_t suggested_start_row) const
{
    /*
      The calculated preferred width of each column
    */
    int pref_width[PAGER_MAX_COLS] = {0};
    /*
      The calculated minimum width of each column
    */
    int min_width[PAGER_MAX_COLS] = {0};
    /*
      If the list can be printed with this width, width will contain the width of each column
    */
    int *width=pref_width;

    /* Set to true if the list should be printed at this width */
    bool print = false;

    /* Compute the effective term width and term height, accounting for disclosure */
    int term_width = this->available_term_width;
    int term_height = this->available_term_height - 1 - (search_field_shown ? 1 : 0); // we always subtract 1 to make room for a comment row
    if (! this->fully_disclosed)
    {
        term_height = mini(term_height, PAGER_UNDISCLOSED_MAX_ROWS);
    }

    /* Completions are stored column-major: each column holds row_count consecutive entries of lst */
    size_t row_count = divide_round_up(lst.size(), cols);

    /* We have more to disclose if we are not fully disclosed and there's more rows than we have in our term height */
    /* NOTE(review): term_height is an int compared against a size_t here and below; a negative term_height would convert to a very large unsigned value - confirm available_term_height is always large enough */
    if (! this->fully_disclosed && row_count > term_height)
    {
        rendering->remaining_to_disclose = row_count - term_height;
    }
    else
    {
        rendering->remaining_to_disclose = 0;
    }

    int pref_tot_width=0;
    int min_tot_width = 0;

    /* Skip completions on tiny terminals. Note this reports success without rendering any lines */
    if (term_width < PAGER_MIN_WIDTH)
        return true;

    /* Calculate how wide the list would be */
    /* NOTE(review): the width arrays have PAGER_MAX_COLS entries and are indexed by col < cols; presumably callers never pass cols > PAGER_MAX_COLS - confirm */
    for (long col = 0; col < cols; col++)
    {
        for (long row = 0; row<row_count; row++)
        {
            int pref,min;
            const comp_t *c;
            /* The last column may be only partially filled */
            if (lst.size() <= col*row_count + row)
                continue;

            c = &lst.at(col*row_count + row);
            pref = c->pref_width;
            min = c->min_width;

            /* Every column except the last gets two extra cells of width */
            if (col != cols-1)
            {
                pref += 2;
                min += 2;
            }
            min_width[col] = maxi(min_width[col],
                                  min);
            pref_width[col] = maxi(pref_width[col],
                                   pref);
        }
        min_tot_width += min_width[col];
        pref_tot_width += pref_width[col];
    }
    /*
      Force fit if one column
    */
    if (cols == 1)
    {
        if (pref_tot_width > term_width)
        {
            pref_width[0] = term_width;
        }
        width = pref_width;
        print = true;
    }
    else if (pref_tot_width <= term_width)
    {
        /* Terminal is wide enough. Print the list! */
        width = pref_width;
        print = true;
    }
    else
    {
        /* The number of rows that would be needed with one column fewer */
        long next_rows = (lst.size()-1)/(cols-1)+1;
        /*    fwprintf( stderr,
          L"cols %d, min_tot %d, term %d, rows=%d, nextrows %d, termrows %d, diff %d\n",
          cols,
          min_tot_width, term_width,
          rows, next_rows, term_height,
          pref_tot_width-term_width );
        */
        if (min_tot_width < term_width &&
                (((row_count < term_height) && (next_rows >= term_height)) ||
                 (pref_tot_width-term_width< 4 && cols < 3)))
        {
            /*
              Terminal almost wide enough, or squeezing makes the
              whole list fit on-screen.

              This part of the code is really important. People hate
              having to scroll through the completion list. In cases
              where there are a huge number of completions, it can't
              be helped, but it is not uncommon for the completions to
              _almost_ fit on one screen. In those cases, it is almost
              always desirable to 'squeeze' the completions into a
              single page.

              If we are using N columns and can get everything to
              fit using squeezing, but everything would also fit
              using N-1 columns, don't try.
            */

            int tot_width = min_tot_width;
            width = min_width;

            /*
              Start from the minimum widths and widen columns one cell at a time,
              round-robin, never beyond their preferred width, until the terminal
              width is used up. This terminates because in this branch the
              preferred total exceeds term_width, so some column can always grow.
            */
            while (tot_width < term_width)
            {
                for (long i=0; (i<cols) && (tot_width < term_width); i++)
                {
                    if (width[i] < pref_width[i])
                    {
                        width[i]++;
                        tot_width++;
                    }
                }
            }
            print = true;
        }
    }

    if (print)
    {
        /* Determine the starting and stop row. stop_row is one past the last row shown */
        size_t start_row = 0, stop_row = 0;
        if (row_count <= term_height)
        {
            /* Easy, we can show everything */
            start_row = 0;
            stop_row = row_count;
        }
        else
        {
            /* We can only show part of the full list. Determine which part based on the suggested_start_row */
            assert(row_count > term_height);
            size_t last_starting_row = row_count - term_height;
            start_row = mini(suggested_start_row, last_starting_row);
            stop_row = start_row + term_height;
            assert(start_row >= 0 && start_row <= last_starting_row);
        }

        assert(stop_row >= start_row);
        assert(stop_row <= row_count);
        assert(stop_row - start_row <= term_height);
        completion_print(cols, width, start_row, stop_row, prefix, lst, rendering);

        /* Ellipsis helper string. Either empty or containing the ellipsis char */
        const wchar_t ellipsis_string[] = {ellipsis_char == L'\x2026' ? L'\x2026' : L'\0', L'\0'};

        /* Add the progress line. It's a "more to disclose" line if necessary, or a row listing if it's scrollable; otherwise ignore it */
        wcstring progress_text;
        if (rendering->remaining_to_disclose == 1)
        {
            /* I don't expect this case to ever happen */
            progress_text = format_string(L"%lsand 1 more row", ellipsis_string);
        }
        else if (rendering->remaining_to_disclose > 1)
        {
            progress_text = format_string(L"%lsand %lu more rows", ellipsis_string, (unsigned long)rendering->remaining_to_disclose);
        }
        else if (start_row > 0 || stop_row < row_count)
        {
            /* We have a scrollable interface. The +1 here is because we are zero indexed, but want to present things as 1-indexed. We do not add 1 to stop_row or row_count because these are the "past the last value" */
            progress_text = format_string(L"rows %lu to %lu of %lu", start_row + 1, stop_row, row_count);
        }
        else if (completion_infos.empty() && ! unfiltered_completion_infos.empty())
        {
            /* Everything is filtered */
            progress_text = L"(no matches)";
        }

        if (! progress_text.empty())
        {
            /* The progress line is appended after the completion rows */
            line_t &line = rendering->screen_data.add_line();
            print_max(progress_text.c_str(), highlight_spec_pager_progress | highlight_make_background(highlight_spec_pager_progress), term_width, true /* has_more */, &line);
        }

        if (search_field_shown)
        {
            /* Add the search field */
            wcstring search_field_text = search_field_line.text;
            /* Append spaces to make it at least the required width */
            if (search_field_text.size() < PAGER_SEARCH_FIELD_WIDTH)
            {
                search_field_text.append(PAGER_SEARCH_FIELD_WIDTH - search_field_text.size(), L' ');
            }
            /* The search field is inserted as the first line, above the completions */
            line_t *search_field = &rendering->screen_data.insert_line_at_index(0);

            /* We limit the width to term_width - 1 */
            int search_field_written = print_max(SEARCH_FIELD_PROMPT, highlight_spec_normal, term_width - 1, false, search_field);
            search_field_written += print_max(search_field_text, highlight_modifier_force_underline, term_width - search_field_written - 1, false, search_field);
        }

    }
    return print;
}
示例#29
0
/*
 * Creates a thread pool with the requested number of worker threads.
 *
 * threads_count: number of worker threads to start. Zero means "use the number
 *                of processors currently online", as reported by sysconf.
 *
 * Returns the new pool, or NULL if the requested size cannot be represented or
 * the memory for the pool cannot be allocated.
 *
 * The function returns only after wait_worker_threads has been called once for
 * the freshly started threads.
 */
struct pthreadpool* pthreadpool_create(size_t threads_count) {
	if (threads_count == 0) {
		/*
		 * sysconf reports failure by returning -1. Casting that directly to size_t
		 * would request SIZE_MAX threads, so fall back to a single worker instead.
		 */
		const long online_processors = sysconf(_SC_NPROCESSORS_ONLN);
		threads_count = (online_processors > 0) ? (size_t) online_processors : 1;
	}

	/* Refuse thread counts for which the allocation size below would wrap around */
	if (threads_count > ((size_t) -1 - sizeof(struct pthreadpool)) / sizeof(struct thread_info)) {
		return NULL;
	}
	const size_t threadpool_size = sizeof(struct pthreadpool) + threads_count * sizeof(struct thread_info);

	/* The pool header and the per-thread records live in one 64-byte aligned allocation */
	struct pthreadpool* threadpool = NULL;
#if !defined(__ANDROID__)
	if (posix_memalign((void**) &threadpool, 64, threadpool_size) != 0) {
		return NULL;
	}
#else
	/*
	 * Android didn't get posix_memalign until API level 17 (Android 4.2).
	 * Use (otherwise obsolete) memalign function on Android platform.
	 */
	threadpool = memalign(64, threadpool_size);
	if (threadpool == NULL) {
		return NULL;
	}
#endif
	memset(threadpool, 0, threadpool_size);
	threadpool->threads_count = threads_count;
	pthread_mutex_init(&threadpool->execution_mutex, NULL);
	pthread_mutex_init(&threadpool->barrier_mutex, NULL);
	pthread_cond_init(&threadpool->barrier_condvar, NULL);
	pthread_mutex_init(&threadpool->state_mutex, NULL);
	pthread_cond_init(&threadpool->state_condvar, NULL);

	/*
	 * NOTE(review): the result of pthread_create is not checked; confirm how
	 * wait_worker_threads behaves if a worker thread fails to start.
	 */
	for (size_t tid = 0; tid < threads_count; tid++) {
		threadpool->threads[tid].thread_number = tid;
		pthread_create(&threadpool->threads[tid].thread_object, NULL, &thread_main, &threadpool->threads[tid]);
	}

	/* Wait until all threads initialize */
	wait_worker_threads(threadpool);
	return threadpool;
}

/* Reports the number of worker threads recorded in the pool. The pool must not be NULL. */
size_t pthreadpool_get_threads_count(struct pthreadpool* threadpool) {
	const size_t worker_count = threadpool->threads_count;
	return worker_count;
}

/*
 * Calls function(argument, i) for every i in [0, range).
 *
 * With a NULL threadpool the calls are made in order on the calling thread.
 * Otherwise the index range is divided between the pool's worker threads and
 * this function returns only after wait_worker_threads has returned.
 */
void pthreadpool_compute_1d(
	struct pthreadpool* threadpool,
	pthreadpool_function_1d_t function,
	void* argument,
	size_t range)
{
	if (threadpool == NULL) {
		/* Without a pool, run every index in order on the caller's thread */
		size_t index = 0;
		while (index < range) {
			function(argument, index);
			index++;
		}
		return;
	}

	/* Guard the pool-wide structures for the whole duration of this computation */
	pthread_mutex_lock(&threadpool->execution_mutex);

	/* Hold the state lock while the job is described so workers never observe a half-written state */
	pthread_mutex_lock(&threadpool->state_mutex);

	/* Publish the callback and its argument */
	threadpool->function = function;
	threadpool->argument = argument;

	/* Give worker number tid the slice [range * tid / count, range * (tid + 1) / count) */
	for (size_t tid = 0; tid < threadpool->threads_count; tid++) {
		struct thread_info* worker = &threadpool->threads[tid];
		worker->range_start = multiply_divide(range, tid, threadpool->threads_count);
		worker->range_end = multiply_divide(range, tid + 1, threadpool->threads_count);
		worker->range_length = worker->range_end - worker->range_start;
		worker->state = thread_state_compute_1d;
	}

	/* Release the state lock before the wake-up, for better performance */
	pthread_mutex_unlock(&threadpool->state_mutex);

	/* Start the workers, then block until they have finished */
	wakeup_worker_threads(threadpool);
	wait_worker_threads(threadpool);

	/* Computation is over: let the next caller use the pool */
	pthread_mutex_unlock(&threadpool->execution_mutex);
}

/* Arguments that pthreadpool_compute_1d_tiled hands to compute_1d_tiled through the pool's opaque argument pointer. */
struct compute_1d_tiled_context {
	/* User callback, invoked once per tile */
	pthreadpool_function_1d_tiled_t function;
	/* Opaque pointer passed back to the callback as its first argument */
	void* argument;
	/* Total number of elements to process */
	size_t range;
	/* Maximum number of elements in one tile */
	size_t tile;
};

/*
 * Adapter that turns a linear tile number into a tiled 1D callback invocation.
 * linear_index counts tiles, not elements.
 */
static void compute_1d_tiled(const struct compute_1d_tiled_context* context, size_t linear_index) {
	const size_t max_tile = context->tile;
	/* First element covered by this tile */
	const size_t start = linear_index * max_tile;
	/* The final tile is shortened so that it does not run past the end of the range */
	const size_t remaining = context->range - start;
	context->function(context->argument, start, min(max_tile, remaining));
}

/*
 * Calls function(argument, start, length) for consecutive tiles covering [0, range).
 * Every tile holds `tile` elements except possibly the last, which holds whatever is left.
 *
 * With a NULL threadpool the tiles are processed in order on the calling thread;
 * otherwise one work item per tile is dispatched through pthreadpool_compute_1d.
 *
 * tile must be non-zero: the sequential loop advances by `tile` and would never
 * make progress otherwise.
 */
void pthreadpool_compute_1d_tiled(
	pthreadpool_t threadpool,
	pthreadpool_function_1d_tiled_t function,
	void* argument,
	size_t range,
	size_t tile)
{
	if (threadpool != NULL) {
		/* Parallel path: the pool iterates over tile numbers and compute_1d_tiled expands each one */
		struct compute_1d_tiled_context context = {
			.function = function,
			.argument = argument,
			.range = range,
			.tile = tile
		};
		const size_t tile_count = divide_round_up(range, tile);
		pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_1d_tiled, &context, tile_count);
		return;
	}

	/* Sequential path: walk the range one tile at a time on the calling thread */
	for (size_t start = 0; start < range; start += tile) {
		function(argument, start, min(range - start, tile));
	}
}

/* Arguments that pthreadpool_compute_2d hands to compute_2d through the pool's opaque argument pointer. */
struct compute_2d_context {
	/* User callback, invoked once per (i, j) pair */
	pthreadpool_function_2d_t function;
	/* Opaque pointer passed back to the callback as its first argument */
	void* argument;
	/* Divisor initialised from range_j; used to split a linear index into (i, j) */
	struct fxdiv_divisor_size_t range_j;
};

static void compute_2d(const struct compute_2d_context* context, size_t linear_index) {
	const struct fxdiv_divisor_size_t range_j = context->range_j;
	const struct fxdiv_result_size_t index = fxdiv_divide_size_t(linear_index, range_j);
	context->function(context->argument, index.quotient, index.remainder);
}

/*
 * Calls function(argument, i, j) for every i in [0, range_i) and j in [0, range_j).
 *
 * With a NULL threadpool the calls are made on the calling thread with j varying
 * fastest; otherwise the range_i * range_j index pairs are linearised and
 * dispatched through pthreadpool_compute_1d.
 */
void pthreadpool_compute_2d(
	struct pthreadpool* threadpool,
	pthreadpool_function_2d_t function,
	void* argument,
	size_t range_i,
	size_t range_j)
{
	if (threadpool != NULL) {
		/* Parallel path: compute_2d recovers (i, j) from each linear index */
		struct compute_2d_context context = {
			.function = function,
			.argument = argument,
			.range_j = fxdiv_init_size_t(range_j)
		};
		pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_2d, &context, range_i * range_j);
		return;
	}

	/* Sequential path: plain nested loops on the calling thread */
	for (size_t row = 0; row < range_i; row++) {
		for (size_t column = 0; column < range_j; column++) {
			function(argument, row, column);
		}
	}
}

/* Arguments that pthreadpool_compute_2d_tiled hands to compute_2d_tiled through the pool's opaque argument pointer. */
struct compute_2d_tiled_context {
	/* User callback, invoked once per tile */
	pthreadpool_function_2d_tiled_t function;
	/* Opaque pointer passed back to the callback as its first argument */
	void* argument;
	/* Divisor initialised from the number of tiles along j; splits a linear tile index into tile coordinates */
	struct fxdiv_divisor_size_t tile_range_j;
	/* Total number of elements along i */
	size_t range_i;
	/* Total number of elements along j */
	size_t range_j;
	/* Maximum tile extent along i */
	size_t tile_i;
	/* Maximum tile extent along j */
	size_t tile_j;
};

static void compute_2d_tiled(const struct compute_2d_tiled_context* context, size_t linear_index) {
	const struct fxdiv_divisor_size_t tile_range_j = context->tile_range_j;
	const struct fxdiv_result_size_t tile_index = fxdiv_divide_size_t(linear_index, tile_range_j);
	const size_t max_tile_i = context->tile_i;
	const size_t max_tile_j = context->tile_j;
	const size_t index_i = tile_index.quotient * max_tile_i;
	const size_t index_j = tile_index.remainder * max_tile_j;
	const size_t tile_i = min(max_tile_i, context->range_i - index_i);
	const size_t tile_j = min(max_tile_j, context->range_j - index_j);
	context->function(context->argument, index_i, index_j, tile_i, tile_j);
}

/*
 * Calls function(argument, i, j, extent_i, extent_j) for tiles covering
 * [0, range_i) x [0, range_j). Tiles are tile_i by tile_j elements, except along
 * the far edges where they are shortened to fit.
 *
 * With a NULL threadpool the tiles are processed on the calling thread with the
 * j tile position varying fastest; otherwise one work item per tile is dispatched
 * through pthreadpool_compute_1d.
 *
 * tile_i and tile_j must be non-zero: the sequential loops advance by them and
 * would never make progress otherwise.
 */
void pthreadpool_compute_2d_tiled(
	pthreadpool_t threadpool,
	pthreadpool_function_2d_tiled_t function,
	void* argument,
	size_t range_i,
	size_t range_j,
	size_t tile_i,
	size_t tile_j)
{
	if (threadpool != NULL) {
		/* Parallel path: linearise the grid of tiles and let compute_2d_tiled expand each tile number */
		const size_t tiles_along_i = divide_round_up(range_i, tile_i);
		const size_t tiles_along_j = divide_round_up(range_j, tile_j);
		struct compute_2d_tiled_context context = {
			.function = function,
			.argument = argument,
			.tile_range_j = fxdiv_init_size_t(tiles_along_j),
			.range_i = range_i,
			.range_j = range_j,
			.tile_i = tile_i,
			.tile_j = tile_j
		};
		pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_2d_tiled, &context, tiles_along_i * tiles_along_j);
		return;
	}

	/* Sequential path: walk the tile grid on the calling thread */
	for (size_t start_i = 0; start_i < range_i; start_i += tile_i) {
		for (size_t start_j = 0; start_j < range_j; start_j += tile_j) {
			function(argument, start_i, start_j, min(range_i - start_i, tile_i), min(range_j - start_j, tile_j));
		}
	}
}

/*
 * Shuts down the worker threads of a pool and releases everything it owns.
 *
 * threadpool: pool returned by pthreadpool_create. NULL is accepted and ignored,
 *             mirroring free(NULL): pthreadpool_create can return NULL and the
 *             compute entry points already treat a NULL pool as valid.
 *
 * The pool must not be used after this call.
 */
void pthreadpool_destroy(struct pthreadpool* threadpool) {
	if (threadpool == NULL) {
		/* Nothing was created, so there is nothing to tear down */
		return;
	}

	/* Update threads' states */
	for (size_t tid = 0; tid < threadpool->threads_count; tid++) {
		threadpool->threads[tid].state = thread_state_shutdown;
	}

	/* Wake up the threads */
	wakeup_worker_threads(threadpool);

	/* Wait until all threads return */
	for (size_t tid = 0; tid < threadpool->threads_count; tid++) {
		pthread_join(threadpool->threads[tid].thread_object, NULL);
	}

	/* Release resources */
	pthread_mutex_destroy(&threadpool->execution_mutex);
	pthread_mutex_destroy(&threadpool->barrier_mutex);
	pthread_cond_destroy(&threadpool->barrier_condvar);
	pthread_mutex_destroy(&threadpool->state_mutex);
	pthread_cond_destroy(&threadpool->state_condvar);
	free(threadpool);
}
示例#30
0
/**
 * _mongoc_gridfs_file_refresh_page:
 *
 *    Refresh a GridFS file's underlying page. This recalculates the current
 *    page number based on the file's stream position, then fetches that page
 *    from the database.
 *
 *    Note that this fetch is unconditional and the page is queried from the
 *    database even if the current page covers the same theoretical chunk.
 *
 * Returns:
 *
 *    The result of seeking within the freshly loaded page on success; 0 if
 *    the chunk could not be fetched or does not have the expected contents.
 *
 * Side Effects:
 *
 *    file->page is loaded with the appropriate buffer, fetched from the
 *    database. If the file position is at the end of the file and on a new
 *    chunk boundary, a new page is created. If the position is far past the
 *    end of the file, _mongoc_gridfs_file_extend is responsible for creating
 *    chunks to fill the gap.
 *
 *    file->n is set based on file->pos. file->error is set on error.
 */
static bool
_mongoc_gridfs_file_refresh_page (mongoc_gridfs_file_t *file)
{
   bson_t query;
   bson_t child;
   bson_t opts;
   const bson_t *chunk;
   const char *key;
   bson_iter_t iter;
   int64_t existing_chunks;
   int64_t required_chunks;

   const uint8_t *data = NULL;
   uint32_t len;

   ENTRY;

   BSON_ASSERT (file);

   /* the chunk number that contains the current stream position */
   file->n = (int32_t) (file->pos / file->chunk_size);

   /* always discard the old page; a new one is built below */
   if (file->page) {
      _mongoc_gridfs_file_page_destroy (file->page);
      file->page = NULL;
   }

   /* if the file pointer is past the end of the current file (i.e. pointing to
    * a new chunk), we'll pass the page constructor a new empty page. */
   existing_chunks = divide_round_up (file->length, file->chunk_size);
   required_chunks = divide_round_up (file->pos + 1, file->chunk_size);
   if (required_chunks > existing_chunks) {
      data = (uint8_t *) "";
      len = 0;
   } else {
      /* if we have a cursor, but the cursor doesn't have the chunk we're going
       * to need, destroy it (we'll grab a new one immediately there after) */
      if (file->cursor && !_mongoc_gridfs_file_keep_cursor (file)) {
         mongoc_cursor_destroy (file->cursor);
         file->cursor = NULL;
      }

      if (!file->cursor) {
         /* query: { files_id: <id>, n: { $gte: <current chunk> } } */
         bson_init (&query);
         BSON_APPEND_VALUE (&query, "files_id", &file->files_id);
         BSON_APPEND_DOCUMENT_BEGIN (&query, "n", &child);
         BSON_APPEND_INT32 (&child, "$gte", file->n);
         bson_append_document_end (&query, &child);

         /* options: sort by ascending n, and return only "n" and "data" */
         bson_init (&opts);
         BSON_APPEND_DOCUMENT_BEGIN (&opts, "sort", &child);
         BSON_APPEND_INT32 (&child, "n", 1);
         bson_append_document_end (&opts, &child);

         BSON_APPEND_DOCUMENT_BEGIN (&opts, "projection", &child);
         BSON_APPEND_INT32 (&child, "n", 1);
         BSON_APPEND_INT32 (&child, "data", 1);
         BSON_APPEND_INT32 (&child, "_id", 0);
         bson_append_document_end (&opts, &child);

         /* find all chunks greater than or equal to our current file pos */
         file->cursor = mongoc_collection_find_with_opts (
            file->gridfs->chunks, &query, &opts, NULL);

         /* cursor_range[0] is the next chunk number the cursor will yield,
          * cursor_range[1] the last chunk number of the file */
         file->cursor_range[0] = file->n;
         file->cursor_range[1] = (uint32_t) (file->length / file->chunk_size);

         bson_destroy (&query);
         bson_destroy (&opts);

         BSON_ASSERT (file->cursor);
      }

      /* we might have had a cursor before, then seeked ahead past a chunk.
       * iterate until we're on the right chunk */
      /* NOTE(review): chunk is only assigned inside this loop; presumably
       * _mongoc_gridfs_file_keep_cursor guarantees at least one iteration for
       * a retained cursor - confirm */
      while (file->cursor_range[0] <= file->n) {
         if (!mongoc_cursor_next (file->cursor, &chunk)) {
            /* copy cursor error; if there's none, we're missing a chunk */
            if (!mongoc_cursor_error (file->cursor, &file->error)) {
               missing_chunk (file);
            }

            RETURN (0);
         }

         file->cursor_range[0]++;
      }

      BSON_ASSERT (bson_iter_init (&iter, chunk));

      /* grab out what we need from the chunk */
      while (bson_iter_next (&iter)) {
         key = bson_iter_key (&iter);

         if (strcmp (key, "n") == 0) {
            /* the server skipped over the chunk we asked for */
            if (file->n != bson_iter_int32 (&iter)) {
               missing_chunk (file);
               RETURN (0);
            }
         } else if (strcmp (key, "data") == 0) {
            bson_iter_binary (&iter, NULL, &len, &data);
         } else {
            /* Unexpected key. This should never happen. Note that no error is
             * recorded in file->error on this path. */
            RETURN (0);
         }
      }

      /* file->n must still agree with the untruncated chunk number. Use the
       * RETURN macro, like every other exit, so it pairs with ENTRY above. */
      if (file->n != file->pos / file->chunk_size) {
         RETURN (0);
      }
   }

   /* data is still NULL when the chunk document carried no "data" field */
   if (!data) {
      bson_set_error (&file->error,
                      MONGOC_ERROR_GRIDFS,
                      MONGOC_ERROR_GRIDFS_CHUNK_MISSING,
                      "corrupt chunk number %" PRId32,
                      file->n);
      RETURN (0);
   }

   file->page = _mongoc_gridfs_file_page_new (data, len, file->chunk_size);

   /* seek in the page towards wherever we're supposed to be */
   RETURN (
      _mongoc_gridfs_file_page_seek (file->page, file->pos % file->chunk_size));
}