示例#1
0
/* Train every stage of the pre-transform pipeline that still needs it.
 *
 * Stages are numbered 0..chain.size(): stages 0..chain.size()-1 are the
 * VectorTransforms in `chain`, and stage chain.size() is the sub-index.
 * Each untrained stage is trained on the output of all preceding
 * transforms. The walk stops right after the last untrained stage, so
 * transforms beyond it are never applied here. On return the object is
 * marked trained.
 *
 * @param n  number of training vectors
 * @param x  training vectors as consumed by chain[0] (layout presumably
 *           n * d_in floats — TODO confirm against VectorTransform)
 */
void IndexPreTransform::train (idx_t n, const float *x)
{
    // Narrow once: `chain.size() - 1` evaluated in unsigned arithmetic
    // would wrap around for an empty chain before conversion to int.
    const int nchain = static_cast<int>(chain.size());

    // Index of the last stage that needs training (nchain == sub-index).
    int last_untrained = 0;
    if (!index->is_trained) {
        last_untrained = nchain;
    } else {
        for (int i = nchain - 1; i >= 0; i--) {
            if (!chain[i]->is_trained) {
                last_untrained = i;
                break;
            }
        }
    }

    // prev_x is the input of stage i: either the caller's x, or an
    // intermediate buffer produced by apply() and owned by `del`.
    const float *prev_x = x;
    ScopeDeleter<float> del;

    for (int i = 0; i <= last_untrained; i++) {
        if (i < nchain) {
            VectorTransform *ltrans = chain [i];
            if (!ltrans->is_trained)
                ltrans->train(n, prev_x);
        } else {
            index->train (n, prev_x);
        }
        if (i == last_untrained) break;

        float * xt = chain[i]->apply (n, prev_x);
        // Intermediate buffers are arrays: release them with delete[]
        // (scalar delete on an array allocation is undefined behavior).
        if (prev_x != x) delete [] prev_x;
        prev_x = xt;
        del.set(xt);
    }

    is_trained = true;
}
示例#2
0
/* Train every stage of the pre-transform pipeline that still needs it,
 * printing progress to stdout when `verbose` is set.
 *
 * Stages are numbered 0..chain.size(): stages 0..chain.size()-1 are the
 * VectorTransforms in `chain`, and stage chain.size() is the sub-index.
 * Each untrained stage is trained on the output of all preceding
 * transforms. The walk stops right after the last untrained stage. When
 * verbose, an OPQMatrix component also has its own `verbose` flag switched
 * on before it is trained. On return the object is marked trained.
 *
 * @param n  number of training vectors
 * @param x  training vectors as consumed by chain[0] (layout presumably
 *           n * d_in floats — TODO confirm against VectorTransform)
 */
void IndexPreTransform::train (idx_t n, const float *x)
{
    // Narrow once: `chain.size() - 1` evaluated in unsigned arithmetic
    // would wrap around for an empty chain before conversion to int.
    const int nchain = static_cast<int>(chain.size());

    // Index of the last stage that needs training (nchain == sub-index).
    int last_untrained = 0;
    if (!index->is_trained) {
        last_untrained = nchain;
    } else {
        for (int i = nchain - 1; i >= 0; i--) {
            if (!chain[i]->is_trained) {
                last_untrained = i;
                break;
            }
        }
    }

    // prev_x is the input of stage i: either the caller's x, or an
    // intermediate buffer produced by apply() and owned by `del`.
    const float *prev_x = x;
    ScopeDeleter<float> del;

    if (verbose) {
        printf("IndexPreTransform::train: training chain 0 to %d\n",
               last_untrained);
    }

    for (int i = 0; i <= last_untrained; i++) {

        if (i < nchain) {
            VectorTransform *ltrans = chain [i];
            if (!ltrans->is_trained) {
                if (verbose) {
                    // chain.size() is a size_t, so %zu is the matching
                    // conversion (%zd expects the signed counterpart).
                    printf("   Training chain component %d/%zu\n",
                           i, chain.size());
                    // Propagate verbosity into OPQ training output.
                    if (OPQMatrix *opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
                        opqm->verbose = true;
                    }
                }
                ltrans->train (n, prev_x);
            }
        } else {
            if (verbose) {
                printf("   Training sub-index\n");
            }
            index->train (n, prev_x);
        }
        if (i == last_untrained) break;
        if (verbose) {
            printf("   Applying transform %d/%zu\n",
                   i, chain.size());
        }

        float * xt = chain[i]->apply (n, prev_x);

        // Free the previous intermediate array (never the caller's x).
        if (prev_x != x) delete [] prev_x;
        prev_x = xt;
        del.set(xt);
    }

    is_trained = true;
}
示例#3
0
/* Push n vectors through every transform of `chain`, in order.
 *
 * Returns the output of the last transform. That buffer is newly allocated
 * and owned by the caller. If `chain` is empty, `x` itself is returned, so
 * callers must compare the result against `x` before freeing it.
 * Intermediate buffers are freed as soon as the next stage has consumed
 * them, and also if a stage throws.
 */
const float *IndexPreTransform::apply_chain (idx_t n, const float *x) const
{
    // Holds the most recent intermediate buffer; never holds x itself.
    ScopeDeleter<float> owner;
    const float *current = x;

    int step = 0;
    while (step < chain.size()) {
        float * produced = chain[step]->apply (n, current);
        {
            // `produced` moves into `owner`; the previous intermediate ends up
            // in `retired` and is freed when this scope closes.
            ScopeDeleter<float> retired (produced);
            retired.swap (owner);
        }
        current = produced;
        ++step;
    }

    // Transfer the final buffer to the caller instead of freeing it.
    owner.release ();
    return current;
}