Example #1
/**
 * @brief This function is the sfunc of an aggregator computing the
 * perplexity.  
 * @param args[0]   The current state 
 * @param args[1]   The unique words in the documents
 * @param args[2]   The counts of each unique word
 * @param args[3]   The topic counts in the document
 * @param args[4]   The model (word topic counts and corpus topic
 *                  counts)
 * @param args[5]   The Dirichlet parameter for per-document topic
 *                  multinomial, i.e. alpha
 * @param args[6]   The Dirichlet parameter for per-topic word
 *                  multinomial, i.e. beta
 * @param args[7]   The size of vocabulary
 * @param args[8]   The number of topics
 * @return          The updated state 
 **/
AnyType lda_perplexity_sfunc::run(AnyType & args){
    ArrayHandle<int32_t> words = args[1].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> counts = args[2].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> topic_counts = args[3].getAs<ArrayHandle<int32_t> >();
    double alpha = args[5].getAs<double>();
    double beta = args[6].getAs<double>();
    int32_t voc_size = args[7].getAs<int32_t>();
    int32_t topic_num = args[8].getAs<int32_t>();

    if(alpha <= 0)
        throw std::invalid_argument("invalid argument - alpha");
    if(beta <= 0)
        throw std::invalid_argument("invalid argument - beta");
    if(voc_size <= 0)
        throw std::invalid_argument(
            "invalid argument - voc_size");
    if(topic_num <= 0)
        throw std::invalid_argument(
            "invalid argument - topic_num");

    if(words.size() != counts.size())
        throw std::invalid_argument(
            "dimensions mismatch: words.size() != counts.size()");
    if(__min(words) < 0 || __max(words) >= voc_size)
        throw std::invalid_argument(
            "invalid values in words");
    if(__min(counts) <= 0)
        throw std::invalid_argument(
            "invalid values in counts");

    if(topic_counts.size() != (size_t)(topic_num))
        throw std::invalid_argument(
            "invalid dimension - topic_counts.size() != topic_num");
    if(__min(topic_counts, 0, topic_num) < 0)
        throw std::invalid_argument("invalid values in topic_counts");

    MutableArrayHandle<int64_t> state(NULL);
    if(args[0].isNull()){
        if(args[4].isNull())
            throw std::invalid_argument("invalid argument - the model \
            parameter should not be null for the first call");
        ArrayHandle<int64_t> model = args[4].getAs<ArrayHandle<int64_t> >();

        if(model.size() != (size_t)((voc_size + 1) * topic_num))
            throw std::invalid_argument(
                "invalid dimension - model.size() != (voc_size + 1) * topic_num");
        if(__min(model) < 0)
            throw std::invalid_argument("invalid topic counts in model");

        state =  madlib_construct_array(NULL,
                                        static_cast<int>(model.size()) + 1,
                                        INT8TI.oid,
                                        INT8TI.len,
                                        INT8TI.byval,
                                        INT8TI.align);

        memcpy(state.ptr(), model.ptr(),  model.size() * sizeof(int64_t));
    }else{
Example #2
/**
 * @brief This function is the finalfunc of an aggregator computing the
 * perplexity.
 * @param args[0]   The global state
 * @return          The perplexity
 **/
AnyType lda_perplexity_ffunc::run(AnyType & args){
    ArrayHandle<int64_t> state = args[0].getAs<ArrayHandle<int64_t> >();
    const double * perp = reinterpret_cast<const double *>(state.ptr() + state.size() - 1);
    return *perp;
}
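The finalfunc above works because the matching sfunc keeps the running perplexity as a double written into the last int64_t slot of the state array; the ffunc simply reinterprets that slot. Below is a minimal standalone sketch of that slot trick (not MADlib code; it uses memcpy instead of reinterpret_cast to keep the type punning well-defined, and the buffer size is illustrative):

// Standalone sketch: storing a double in the trailing int64_t slot of a state
// buffer and reading it back, as the finalfunc above does via reinterpret_cast.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
    static_assert(sizeof(double) == sizeof(int64_t),
                  "the trailing slot must be able to hold a double");

    std::vector<int64_t> state(5, 0);   // counts ..., trailing perplexity slot

    // sfunc side: accumulate into the trailing slot as a double
    double perp = 0.0;
    std::memcpy(&perp, &state.back(), sizeof(double));
    perp += -123.45;                    // e.g. n_dw * log(sum_p)
    std::memcpy(&state.back(), &perp, sizeof(double));

    // ffunc side: reinterpret the trailing slot as the final result
    double result = 0.0;
    std::memcpy(&result, &state.back(), sizeof(double));
    std::cout << "perplexity slot = " << result << "\n";   // prints -123.45
    return 0;
}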
Example #3
/**
 * @brief This function is the sfunc of an aggregator computing the
 * perplexity.
 * @param args[0]   The current state
 * @param args[1]   The unique words in the documents
 * @param args[2]   The counts of each unique word
 * @param args[3]   The topic counts in the document
 * @param args[4]   The model (word topic counts and corpus topic
 *                  counts)
 * @param args[5]   The Dirichlet parameter for per-document topic
 *                  multinomial, i.e. alpha
 * @param args[6]   The Dirichlet parameter for per-topic word
 *                  multinomial, i.e. beta
 * @param args[7]   The size of vocabulary
 * @param args[8]   The number of topics
 * @return          The updated state
 **/
AnyType lda_perplexity_sfunc::run(AnyType & args){
    ArrayHandle<int32_t> words = args[1].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> counts = args[2].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> doc_topic_counts = args[3].getAs<ArrayHandle<int32_t> >();
    double alpha = args[5].getAs<double>();
    double beta = args[6].getAs<double>();
    int32_t voc_size = args[7].getAs<int32_t>();
    int32_t topic_num = args[8].getAs<int32_t>();
    size_t model64_size = static_cast<size_t>(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t);

    if(alpha <= 0)
        throw std::invalid_argument("invalid argument - alpha");
    if(beta <= 0)
        throw std::invalid_argument("invalid argument - beta");
    if(voc_size <= 0)
        throw std::invalid_argument(
            "invalid argument - voc_size");
    if(topic_num <= 0)
        throw std::invalid_argument(
            "invalid argument - topic_num");

    if(words.size() != counts.size())
        throw std::invalid_argument(
            "dimensions mismatch: words.size() != counts.size()");
    if(__min(words) < 0 || __max(words) >= voc_size)
        throw std::invalid_argument(
            "invalid values in words");
    if(__min(counts) <= 0)
        throw std::invalid_argument(
            "invalid values in counts");

    if(doc_topic_counts.size() != (size_t)(topic_num))
        throw std::invalid_argument(
            "invalid dimension - doc_topic_counts.size() != topic_num");
    if(__min(doc_topic_counts, 0, topic_num) < 0)
        throw std::invalid_argument("invalid values in doc_topic_counts");

    MutableArrayHandle<int64_t> state(NULL);
    if (args[0].isNull()) {
        ArrayHandle<int64_t> model64 = args[4].getAs<ArrayHandle<int64_t> >();

        if (model64.size() != model64_size) {
            std::stringstream ss;
            ss << "invalid dimension: model64.size() = " << model64.size();
            throw std::invalid_argument(ss.str());
        }
        if(__min(model64) < 0) {
            throw std::invalid_argument("invalid topic counts in model");
        }

        state =  madlib_construct_array(NULL,
                                        static_cast<int>(model64.size())
                                            + topic_num
                                            + sizeof(double) / sizeof(int64_t),
                                        INT8TI.oid,
                                        INT8TI.len,
                                        INT8TI.byval,
                                        INT8TI.align);

        memcpy(state.ptr(), model64.ptr(), model64.size() * sizeof(int64_t));
        int32_t *_model = reinterpret_cast<int32_t *>(state.ptr());
        int64_t *_total_topic_counts = reinterpret_cast<int64_t *>(state.ptr() + model64.size());
        for (int i = 0; i < voc_size; i ++) {
            for (int j = 0; j < topic_num; j ++) {
                _total_topic_counts[j] += _model[i * (topic_num + 1) + j];
            }
        }
    } else {
        state = args[0].getAs<MutableArrayHandle<int64_t> >();
    }

    int32_t *model = reinterpret_cast<int32_t *>(state.ptr());
    int64_t *total_topic_counts = reinterpret_cast<int64_t *>(state.ptr() + model64_size);
    double *perp = reinterpret_cast<double *>(state.ptr() + state.size() - 1);

    int32_t n_d = 0;
    for(size_t i = 0; i < words.size(); i++){
        n_d += counts[i];
    }

    for(size_t i = 0; i < words.size(); i++){
        int32_t w = words[i];
        int32_t n_dw = counts[i];

        double sum_p = 0.0;
        for(int32_t z = 0; z < topic_num; z++){
            int32_t n_dz = doc_topic_counts[z];
            int32_t n_wz = model[w * (topic_num + 1) + z];
            int64_t n_z = total_topic_counts[z];
            sum_p += (static_cast<double>(n_wz) + beta) * (n_dz + alpha)
                        / (static_cast<double>(n_z) + voc_size * beta);
        }
        sum_p /= (n_d + topic_num * alpha);

        *perp += n_dw * log(sum_p);
    }

    return state;
}
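The state this sfunc builds is laid out as the int32 model packed into bigint[], followed by topic_num int64 corpus topic totals, followed by one int64 slot reused as a double for the running perplexity. The following standalone sketch reproduces that size arithmetic with small illustrative values for voc_size and topic_num (the numbers are not from the source):

// Standalone sketch of the assumed state layout used by the sfunc above.
#include <cstdint>
#include <cstddef>
#include <iostream>

int main() {
    const int32_t voc_size = 5, topic_num = 3;   // illustrative values only

    // int32 cells of the model; the +1 makes the int32-to-int64 division
    // round up, so an odd cell count still fits (padding cell).
    const size_t model32_cells = static_cast<size_t>(voc_size) * (topic_num + 1) + 1;
    const size_t model64_size  = model32_cells * sizeof(int32_t) / sizeof(int64_t);

    const size_t state_size = model64_size                       // packed int32 model
                            + topic_num                          // total_topic_counts
                            + sizeof(double) / sizeof(int64_t);  // perplexity slot

    std::cout << "model64_size = " << model64_size << " bigints\n"; // (5*4+1)*4/8 = 10
    std::cout << "state_size   = " << state_size   << " bigints\n"; // 10 + 3 + 1 = 14
    return 0;
}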
Example #4
/**
 * @brief This function is the sfunc for the aggregator computing the topic
 * counts. It scans the topic assignments in a document and updates the word
 * topic counts.
 * @param args[0]   The state variable, current topic counts
 * @param args[1]   The unique words in the document
 * @param args[2]   The counts of each unique word in the document
 * @param args[3]   The topic assignments in the document
 * @param args[4]   The size of vocabulary
 * @param args[5]   The number of topics
 * @return          The updated state
 **/
AnyType lda_count_topic_sfunc::run(AnyType & args)
{
    if(args[4].isNull() || args[5].isNull())
        throw std::invalid_argument(
            "null parameter - voc_size and/or topic_num is null");

    if(args[1].isNull() || args[2].isNull() || args[3].isNull())
        return args[0];

    int32_t voc_size = args[4].getAs<int32_t>();
    int32_t topic_num = args[5].getAs<int32_t>();
    if(voc_size <= 0)
        throw std::invalid_argument(
            "invalid argument - voc_size");
    if(topic_num <= 0)
        throw std::invalid_argument(
            "invalid argument - topic_num");

    ArrayHandle<int32_t> words = args[1].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> counts = args[2].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> topic_assignment = args[3].getAs<ArrayHandle<int32_t> >();
    if(words.size() != counts.size())
        throw std::invalid_argument(
            "dimensions mismatch - words.size() != counts.size()");
    if(__min(words) < 0 || __max(words) >= voc_size)
        throw std::invalid_argument(
            "invalid values in words");
    if(__min(counts) <= 0)
        throw std::invalid_argument(
            "invalid values in counts");
    if(__min(topic_assignment) < 0 || __max(topic_assignment) >= topic_num)
        throw std::invalid_argument("invalid values in topics");
    if((size_t)__sum(counts) != topic_assignment.size())
        throw std::invalid_argument(
            "dimension mismatch - sum(counts) != topic_assignment.size()");

    MutableArrayHandle<int64_t> state(NULL);
    int32_t *model;
    if(args[0].isNull()) {
        // To store a voc_size x (topic_num + 1) int32 matrix in bigint[]
        // (the extra column is a flag set when a count hits its ceiling),
        // we need padding if the number of int32 cells is odd:
        // 1. voc_size * (topic_num + 1) == 2n + 1  gives  n + 1 bigints
        // 2. voc_size * (topic_num + 1) == 2n      gives  n bigints
        int dims[1] = {static_cast<int>( (voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t) )};
        int lbs[1] = {1};
        state = madlib_construct_md_array(
            NULL, NULL, 1, dims, lbs, INT8TI.oid, INT8TI.len, INT8TI.byval,
            INT8TI.align);
        // We use bigint[] because integer[] has a limit on the number of
        // elements and thus cannot be larger than 500MB.
        model = reinterpret_cast<int32_t *>(state.ptr());
    } else {
        state = args[0].getAs<MutableArrayHandle<int64_t> >();
        model = reinterpret_cast<int32_t *>(state.ptr());
    }

    int32_t unique_word_count = static_cast<int32_t>(words.size());
    int32_t word_index = 0;
    for(int32_t i = 0; i < unique_word_count; i++){
        int32_t wordid = words[i];
        for(int32_t j = 0; j < counts[i]; j++){
            int32_t topic = topic_assignment[word_index];
            if (model[wordid * (topic_num + 1) + topic] <= 2e9) {
                model[wordid * (topic_num + 1) + topic]++;
            } else {
                model[wordid * (topic_num + 1) + topic_num] = 1;
            }
            word_index++;
        }
    }
    return state;
}
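The update loop above caps each int32 count near 2e9 and, instead of overflowing, sets a flag in the extra column at index topic_num. Here is a minimal standalone sketch of that saturation pattern; the bump helper and the example sizes are illustrative, not part of the source:

// Standalone sketch: per-word counts live in an int32 row of width
// topic_num + 1, whose last column is an overflow flag.
#include <cstdint>
#include <iostream>
#include <vector>

static void bump(std::vector<int32_t>& model, int32_t topic_num,
                 int32_t wordid, int32_t topic) {
    int32_t* row = &model[wordid * (topic_num + 1)];
    if (row[topic] <= 2e9) {
        row[topic]++;            // normal case: count the assignment
    } else {
        row[topic_num] = 1;      // saturation: raise the per-word flag instead
    }
}

int main() {
    const int32_t voc_size = 4, topic_num = 3;
    std::vector<int32_t> model(voc_size * (topic_num + 1), 0);

    bump(model, topic_num, /*wordid=*/2, /*topic=*/1);
    bump(model, topic_num, /*wordid=*/2, /*topic=*/1);
    std::cout << model[2 * (topic_num + 1) + 1] << "\n";          // 2
    std::cout << model[2 * (topic_num + 1) + topic_num] << "\n";  // 0 (no overflow)
    return 0;
}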
Example #5
/**
 * @brief This function learns the topics of words in a document and is the
 * main step of a Gibbs sampling iteration. The word topic counts and
 * corpus topic counts are passed to this function in the first call and
 * then transferred to the remaining calls through args.mSysInfo->user_fctx for
 * efficiency.
 * @param args[0]   The unique words in the documents
 * @param args[1]   The counts of each unique word
 * @param args[2]   The topic counts and topic assignments in the document
 * @param args[3]   The model (word topic counts and corpus topic
 *                  counts)
 * @param args[4]   The Dirichlet parameter for per-document topic
 *                  multinomial, i.e. alpha
 * @param args[5]   The Dirichlet parameter for per-topic word
 *                  multinomial, i.e. beta
 * @param args[6]   The size of vocabulary
 * @param args[7]   The number of topics
 * @param args[8]   The number of iterations (=1:training, >1:prediction)
 * @return          The updated topic counts and topic assignments for
 *                  the document
 **/
AnyType lda_gibbs_sample::run(AnyType & args)
{
    ArrayHandle<int32_t> words = args[0].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> counts = args[1].getAs<ArrayHandle<int32_t> >();
    MutableArrayHandle<int32_t> doc_topic = args[2].getAs<MutableArrayHandle<int32_t> >();
    double alpha = args[4].getAs<double>();
    double beta = args[5].getAs<double>();
    int32_t voc_size = args[6].getAs<int32_t>();
    int32_t topic_num = args[7].getAs<int32_t>();
    int32_t iter_num = args[8].getAs<int32_t>();
    size_t model64_size = static_cast<size_t>(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t);

    if(alpha <= 0)
        throw std::invalid_argument("invalid argument - alpha");
    if(beta <= 0)
        throw std::invalid_argument("invalid argument - beta");
    if(voc_size <= 0)
        throw std::invalid_argument(
            "invalid argument - voc_size");
    if(topic_num <= 0)
        throw std::invalid_argument(
            "invalid argument - topic_num");
    if(iter_num <= 0)
        throw std::invalid_argument(
            "invalid argument - iter_num");

    if(words.size() != counts.size())
        throw std::invalid_argument(
            "dimensions mismatch: words.size() != counts.size()");
    if(__min(words) < 0 || __max(words) >= voc_size)
        throw std::invalid_argument(
            "invalid values in words");
    if(__min(counts) <= 0)
        throw std::invalid_argument(
            "invalid values in counts");

    int32_t word_count = __sum(counts);
    if(doc_topic.size() != (size_t)(word_count + topic_num))
        throw std::invalid_argument(
            "invalid dimension - doc_topic.size() != word_count + topic_num");
    if(__min(doc_topic, 0, topic_num) < 0)
        throw std::invalid_argument("invalid values in topic_count");
    if(
        __min(doc_topic, topic_num, word_count) < 0 ||
        __max(doc_topic, topic_num, word_count) >= topic_num)
        throw std::invalid_argument( "invalid values in topic_assignment");

    if (!args.getUserFuncContext()) {
        ArrayHandle<int64_t> model64 = args[3].getAs<ArrayHandle<int64_t> >();
        if (model64.size() != model64_size) {
            std::stringstream ss;
            ss << "invalid dimension: model64.size() = " << model64.size();
            throw std::invalid_argument(ss.str());
        }
        if (__min(model64) < 0) {
            throw std::invalid_argument("invalid topic counts in model");
        }

        int32_t *context =
            static_cast<int32_t *>(
                MemoryContextAllocZero(
                    args.getCacheMemoryContext(),
                    model64.size() * sizeof(int64_t)
                        + topic_num * sizeof(int64_t)));
        memcpy(context, model64.ptr(), model64.size() * sizeof(int64_t));
        int32_t *model = context;

        int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
                context + model64_size * sizeof(int64_t) / sizeof(int32_t));
        for (int i = 0; i < voc_size; i ++) {
            for (int j = 0; j < topic_num; j ++) {
                running_topic_counts[j] += model[i * (topic_num + 1) + j];
            }
        }

        args.setUserFuncContext(context);
    }

    int32_t *context = static_cast<int32_t *>(args.getUserFuncContext());
    if (context == NULL) {
        throw std::runtime_error("args.mSysInfo->user_fctx is null");
    }
    int32_t *model = context;
    int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
            context + model64_size * sizeof(int64_t) / sizeof(int32_t));

    int32_t unique_word_count = static_cast<int32_t>(words.size());
    for(int32_t it = 0; it < iter_num; it++){
        int32_t word_index = topic_num;
        for(int32_t i = 0; i < unique_word_count; i++) {
            int32_t wordid = words[i];
            for(int32_t j = 0; j < counts[i]; j++){
                int32_t topic = doc_topic[word_index];
                int32_t retopic = __lda_gibbs_sample(
                    topic_num, topic, doc_topic.ptr(),
                    model + wordid * (topic_num + 1),
                    running_topic_counts, alpha, beta);
                doc_topic[word_index] = retopic;
                doc_topic[topic]--;
                doc_topic[retopic]++;

                if(iter_num == 1) {
                    if (model[wordid * (topic_num + 1) + retopic] <= 2e9) {
                        running_topic_counts[topic] --;
                        running_topic_counts[retopic] ++;
                        model[wordid * (topic_num + 1) + topic]--;
                        model[wordid * (topic_num + 1) + retopic]++;
                    } else {
                        model[wordid * (topic_num + 1) + topic_num] = 1;
                    }
                }
                word_index++;
            }
        }
    }

    return doc_topic;
}
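__lda_gibbs_sample itself is not shown in these examples. As a rough illustration only, the sketch below draws a topic from the unnormalized per-topic weights (n_wz + beta) * (n_dz + alpha) / (n_z + voc_size * beta), the same expression used for the perplexity in Example #3. The real helper also receives the token's current topic so it can exclude that assignment before drawing, which is omitted here; the function name and values are hypothetical, not MADlib's implementation:

// Standalone, assumed sketch of a collapsed Gibbs draw over unnormalized
// per-topic weights (not the actual __lda_gibbs_sample implementation).
#include <cstdint>
#include <cstdlib>
#include <vector>

static int32_t sample_topic(int32_t topic_num,
                            const int32_t* doc_topic_counts,   // n_dz per topic
                            const int32_t* word_topic_counts,  // n_wz per topic
                            const int64_t* total_topic_counts, // n_z per topic
                            double alpha, double beta, int32_t voc_size) {
    // Note: a full collapsed sampler would first subtract the token's current
    // assignment from these counts; that step is omitted in this sketch.
    std::vector<double> cum(topic_num);
    double sum = 0.0;
    for (int32_t z = 0; z < topic_num; ++z) {
        sum += (word_topic_counts[z] + beta) * (doc_topic_counts[z] + alpha)
               / (total_topic_counts[z] + voc_size * beta);
        cum[z] = sum;                                  // cumulative weights
    }
    double u = sum * (std::rand() / (RAND_MAX + 1.0)); // uniform in [0, sum)
    for (int32_t z = 0; z < topic_num; ++z)
        if (u < cum[z]) return z;
    return topic_num - 1;                              // guard against rounding
}

int main() {
    const int32_t topic_num = 3, voc_size = 10;        // illustrative values
    int32_t n_dz[] = {2, 0, 1};
    int32_t n_wz[] = {5, 1, 0};
    int64_t n_z[]  = {40, 30, 20};
    return sample_topic(topic_num, n_dz, n_wz, n_z, 0.1, 0.01, voc_size);
}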
Example #6
/**
 * @brief Get the sum of an array - for parameter checking
 * @return      The sum
 * @note The caller will ensure that ah is always non-null.
 **/
static int32_t __sum(ArrayHandle<int32_t> ah){
    const int32_t * array = ah.ptr();
    size_t size = ah.size();
    return std::accumulate(array, array + size, static_cast<int32_t>(0));
}
Example #7
template<class T> static T __max(ArrayHandle<T> ah){
    return __max(ah, 0, ah.size());
}
AnyType vcrf_top1_label::run(AnyType& args) {

    ArrayHandle<double> mArray = args[0].getAs<ArrayHandle<double> >();
    ArrayHandle<double> rArray = args[1].getAs<ArrayHandle<double> >();
    const int32_t numLabels = args[2].getAs<int32_t>();

    if (numLabels == 0)
        throw std::invalid_argument("Number of labels cannot be zero");

    int doc_len = static_cast<int>(rArray.size() / numLabels);

    double* prev_top1_array = new double[numLabels];
    double* curr_top1_array = new double[numLabels];
    double* prev_norm_array = new double[numLabels];
    double* curr_norm_array = new double[numLabels];
    int* path = new int[doc_len*numLabels];

    memset(prev_top1_array, 0, numLabels*sizeof(double));
    memset(prev_norm_array, 0, numLabels*sizeof(double));
    memset(path, 0, doc_len*numLabels*sizeof(int));

    for(int start_pos = 0; start_pos < doc_len; start_pos++) {
        memset(curr_top1_array, 0, numLabels*sizeof(double));
        memset(curr_norm_array, 0, numLabels*sizeof(double));

        if (start_pos == 0) {
            for (int label = 0; label < numLabels; label++) {
                 curr_norm_array[label] = rArray[label] + mArray[label];
                 curr_top1_array[label] = rArray[label] + mArray[label];
            }
        } else {
            for (int curr_label = 0; curr_label < numLabels; curr_label++) {
                for (int prev_label = 0; prev_label < numLabels; prev_label++) {
                    double top1_new_score = prev_top1_array[prev_label]
                                               + rArray[start_pos*numLabels + curr_label]
                                               + mArray[(prev_label+1)*numLabels + curr_label];

                    if (start_pos == doc_len - 1)
                        top1_new_score += mArray[(numLabels+1)*numLabels + curr_label];

                    if (top1_new_score > curr_top1_array[curr_label]) {
                        curr_top1_array[curr_label] = top1_new_score;
                        path[start_pos*numLabels + curr_label] = prev_label;
                    }

                    /* calculate the probability of the best label sequence */
                    double norm_new_score = prev_norm_array[prev_label]
                                               + rArray[start_pos * numLabels + curr_label]
                                               + mArray[(prev_label+1)*numLabels + curr_label];

                    /* last token in a sentence, the end feature should be fired */
                    if (start_pos == doc_len - 1)
                        norm_new_score += mArray[(numLabels+1)*numLabels + curr_label];

                    /* The following computes z = log(exp(x)+exp(y)); a faster
                     * equivalent is z = min(x,y) + log(exp(abs(x-y))+1).
                     * The 0.5 is for rounding.
                     */
                    if (curr_norm_array[curr_label] == 0)
                        curr_norm_array[curr_label] = norm_new_score;
                    else {
                        double x = curr_norm_array[curr_label];
                        double y = norm_new_score;
                        curr_norm_array[curr_label] = std::min(x,y) +
                                static_cast<double>(log(std::exp(std::abs(y-x)/1000.0) +1)*1000.0 + 0.5);
                    }
                }
            }
        }
        for (int label = 0; label < numLabels; label++) {
            prev_top1_array[label] = curr_top1_array[label];
            prev_norm_array[label] = curr_norm_array[label];
        }
    }

    /* find the label of the last token in a sentence */
    double max_score = 0.0;
    int top1_label = 0;
    for(int label = 0; label < numLabels; label++) {
        if(curr_top1_array[label] > max_score) {
            max_score = curr_top1_array[label];
            top1_label = label;
        }
    }

    /* Define the result array with doc_len+1 elements, where the first doc_len
     * elements are used to store the best labels and the last element is used
     * to store the conditional probability of the sequence.
     */
    MutableArrayHandle<int> result(
        madlib_construct_array(
            NULL, doc_len+1, INT4TI.oid,
               INT4TI.len, INT4TI.byval, INT4TI.align));

    /* trace back to get the labels for the remaining tokens in the sentence */
    result[doc_len - 1] = top1_label;
    for (int pos = doc_len - 1; pos >= 1; pos--) {
        top1_label = path[pos * numLabels + top1_label];
        result[pos-1] = top1_label;
    }

    /* Compute sum_i of log(v1[i]/1000) and return (e^sum)*1000; this is
     * used in UDFs that need marginalization, e.g. normalization.
     * The following computes z = log(exp(x)+exp(y)); a faster equivalent is
     * z = min(x,y) + log(exp(abs(x-y))+1).
     */
    double norm_factor = 0.0;
    for (int i = 0; i < numLabels; i++) {
        if (i==0)
            norm_factor = curr_norm_array[0];
        else {
            double x = curr_norm_array[i];
            double y = norm_factor;
            norm_factor = std::min(x,y) + static_cast<double>(log(exp(std::abs(y-x)/1000.0) +1)*1000.0+0.5);
         }
    }

    /* Calculate the conditional probability. To convert it into an integer,
     * multiply by 1000000 here; dividing the stored value by 1000000 later
     * recovers the real conditional probability.
     */
    result[doc_len] = static_cast<int>(std::exp((max_score - norm_factor)/1000.0)*1000000);

    delete[] prev_top1_array;
    delete[] curr_top1_array;
    delete[] prev_norm_array;
    delete[] curr_norm_array;
    delete[] path;

    return result;
}
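The Viterbi code above repeatedly applies the identity log(exp(x) + exp(y)) = min(x, y) + log(exp(|x - y|) + 1) on scores scaled by 1000. The standalone check below verifies the identity numerically with illustrative values; it is not part of the CRF module:

// Standalone numerical check of the log-sum-exp identity used above.
#include <cmath>
#include <cstdio>
#include <algorithm>

int main() {
    const double scale = 1000.0;
    double x = 2345.0, y = 1789.0;          // scores in the x1000 fixed-point scale

    double a = x / scale, b = y / scale;
    double direct = std::log(std::exp(a) + std::exp(b));
    double fused  = std::min(a, b) + std::log(std::exp(std::fabs(a - b)) + 1.0);

    std::printf("direct = %.6f, fused = %.6f\n", direct, fused);  // both agree
    std::printf("scaled back (as the CRF code stores it) = %.1f\n",
                fused * scale);             // the code also adds 0.5 before use
    return 0;
}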
Example #9
/**
 * @brief This function is the sfunc for the aggregator computing the topic
 * counts. It scans the topic assignments in a document and updates the word
 * topic counts.
 * @param args[0]   The state variable, current topic counts
 * @param args[1]   The unique words in the document
 * @param args[2]   The counts of each unique word in the document
 * @param args[3]   The topic assignments in the document
 * @param args[4]   The size of vocabulary
 * @param args[5]   The number of topics 
 * @return          The updated state
 **/
AnyType lda_count_topic_sfunc::run(AnyType & args)
{
    if(args[4].isNull() || args[5].isNull())
        throw std::invalid_argument(
            "null parameter - voc_size and/or topic_num is null");

    if(args[1].isNull() || args[2].isNull() || args[3].isNull()) 
        return args[0];

    int32_t voc_size = args[4].getAs<int32_t>();
    int32_t topic_num = args[5].getAs<int32_t>();
    if(voc_size <= 0)
        throw std::invalid_argument(
            "invalid argument - voc_size");
    if(topic_num <= 0)
        throw std::invalid_argument(
            "invalid argument - topic_num");

    ArrayHandle<int32_t> words = args[1].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> counts = args[2].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> topic_assignment = args[3].getAs<ArrayHandle<int32_t> >();
    if(words.size() != counts.size())
        throw std::invalid_argument(
            "dimensions mismatch - words.size() != counts.size()");
    if(__min(words) < 0 || __max(words) >= voc_size)
        throw std::invalid_argument(
            "invalid values in words");
    if(__min(counts) <= 0)
        throw std::invalid_argument(
            "invalid values in counts");
    if(__min(topic_assignment) < 0 || __max(topic_assignment) >= topic_num)
        throw std::invalid_argument("invalid values in topics");
    if((size_t)__sum(counts) != topic_assignment.size())
        throw std::invalid_argument(
            "dimension mismatch - sum(counts) != topic_assignment.size()");

    MutableArrayHandle<int64_t> state(NULL);
    if(args[0].isNull()){
        int dims[2] = {voc_size + 1, topic_num};
        int lbs[2] = {1, 1};
        state = madlib_construct_md_array(
            NULL, NULL, 2, dims, lbs, INT8TI.oid, INT8TI.len, INT8TI.byval,
            INT8TI.align);
    } else {
        state = args[0].getAs<MutableArrayHandle<int64_t> >();
    }

    int32_t unique_word_count = static_cast<int32_t>(words.size());
    int32_t word_index = 0;
    for(int32_t i = 0; i < unique_word_count; i++){
        int32_t wordid = words[i];
        for(int32_t j = 0; j < counts[i]; j++){
            int32_t topic = topic_assignment[word_index];
            state[wordid * topic_num + topic]++;
            state[voc_size * topic_num + topic]++;
            word_index++;
        }
    }

    return state;
}
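This older variant stores the model as a flat (voc_size + 1) x topic_num bigint array: rows 0..voc_size-1 hold per-word topic counts and the extra row at voc_size holds the corpus-wide totals, which is why each assignment increments two cells. A small standalone sketch of that indexing, with illustrative sizes not taken from the source:

// Standalone sketch of the flat (voc_size + 1) x topic_num layout above.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    const int32_t voc_size = 4, topic_num = 3;
    std::vector<int64_t> state((voc_size + 1) * topic_num, 0);

    auto count_assignment = [&](int32_t wordid, int32_t topic) {
        state[wordid * topic_num + topic]++;      // word-topic cell
        state[voc_size * topic_num + topic]++;    // corpus total for that topic
    };

    count_assignment(/*wordid=*/1, /*topic=*/2);
    count_assignment(/*wordid=*/3, /*topic=*/2);

    std::cout << state[1 * topic_num + 2] << "\n";         // 1
    std::cout << state[voc_size * topic_num + 2] << "\n";  // 2
    return 0;
}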
Example #10
/**
 * @brief This function learns the topics of words in a document and is the
 * main step of a Gibbs sampling iteration. The word topic counts and
 * corpus topic counts are passed to this function in the first call and
 * then transferred to the remaining calls through args.mSysInfo->user_fctx for
 * efficiency. 
 * @param args[0]   The unique words in the documents
 * @param args[1]   The counts of each unique word
 * @param args[2]   The topic counts and topic assignments in the document
 * @param args[3]   The model (word topic counts and corpus topic
 *                  counts)
 * @param args[4]   The Dirichlet parameter for per-document topic
 *                  multinomial, i.e. alpha
 * @param args[5]   The Dirichlet parameter for per-topic word
 *                  multinomial, i.e. beta
 * @param args[6]   The size of vocabulary
 * @param args[7]   The number of topics
 * @param args[8]   The number of iterations (=1:training, >1:prediction)
 * @return          The updated topic counts and topic assignments for
 *                  the document
 **/
AnyType lda_gibbs_sample::run(AnyType & args)
{
    ArrayHandle<int32_t> words = args[0].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> counts = args[1].getAs<ArrayHandle<int32_t> >();
    MutableArrayHandle<int32_t> doc_topic = args[2].getAs<MutableArrayHandle<int32_t> >();
    double alpha = args[4].getAs<double>();
    double beta = args[5].getAs<double>();
    int32_t voc_size = args[6].getAs<int32_t>();
    int32_t topic_num = args[7].getAs<int32_t>();
    int32_t iter_num = args[8].getAs<int32_t>();

    if(alpha <= 0)
        throw std::invalid_argument("invalid argument - alpha");
    if(beta <= 0)
        throw std::invalid_argument("invalid argument - beta");
    if(voc_size <= 0)
        throw std::invalid_argument(
            "invalid argument - voc_size");
    if(topic_num <= 0)
        throw std::invalid_argument(
            "invalid argument - topic_num");
    if(iter_num <= 0)
        throw std::invalid_argument(
            "invalid argument - iter_num");

    if(words.size() != counts.size())
        throw std::invalid_argument(
            "dimensions mismatch: words.size() != counts.size()");
    if(__min(words) < 0 || __max(words) >= voc_size)
        throw std::invalid_argument(
            "invalid values in words");
    if(__min(counts) <= 0)
        throw std::invalid_argument(
            "invalid values in counts");

    int32_t word_count = __sum(counts);
    if(doc_topic.size() != (size_t)(word_count + topic_num))
        throw std::invalid_argument(
            "invalid dimension - doc_topic.size() != word_count + topic_num");
    if(__min(doc_topic, 0, topic_num) < 0)
        throw std::invalid_argument("invalid values in topic_count");
    if(
        __min(doc_topic, topic_num, word_count) < 0 ||
        __max(doc_topic, topic_num, word_count) >= topic_num)
        throw std::invalid_argument( "invalid values in topic_assignment");

    if (!args.getUserFuncContext())
    {
        if(args[3].isNull())
            throw std::invalid_argument(
                "invalid argument - the model parameter should not be "
                "null for the first call");
        ArrayHandle<int64_t> model = args[3].getAs<ArrayHandle<int64_t> >();
        if(model.size() != (size_t)((voc_size + 1) * topic_num))
            throw std::invalid_argument(
                "invalid dimension - model.size() != (voc_size + 1) * topic_num");
        if(__min(model) < 0)
            throw std::invalid_argument("invalid topic counts in model");

        int64_t * state = 
            static_cast<int64_t *>(
                MemoryContextAllocZero(
                    args.getCacheMemoryContext(), 
                    model.size() * sizeof(int64_t)));
        memcpy(state, model.ptr(), model.size() * sizeof(int64_t));
        args.setUserFuncContext(state);
    }

    int64_t * state = static_cast<int64_t *>(args.getUserFuncContext());
    if(NULL == state){
        throw std::runtime_error("args.mSysInfo->user_fctx is null");
    }

    int32_t unique_word_count = static_cast<int32_t>(words.size());
    for(int32_t it = 0; it < iter_num; it++){
        int32_t word_index = topic_num;
        for(int32_t i = 0; i < unique_word_count; i++) {
            int32_t wordid = words[i];
            for(int32_t j = 0; j < counts[i]; j++){
                int32_t topic = doc_topic[word_index];
                int32_t retopic = __lda_gibbs_sample(
                    topic_num, topic, doc_topic.ptr(), 
                    state + wordid * topic_num, 
                    state + voc_size * topic_num, alpha, beta);
                doc_topic[word_index] = retopic;
                doc_topic[topic]--;
                doc_topic[retopic]++;

                if(iter_num == 1){
                    state[voc_size * topic_num + topic]--;
                    state[voc_size * topic_num + retopic]++;
                    state[wordid * topic_num + topic]--;
                    state[wordid * topic_num + retopic]++;
                }
                word_index++;
            }
        }
    }
    
    return doc_topic;
}
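In the training case (iter_num == 1) above, reassigning a token from its old topic to the sampled one has to move a count in three places at once: the document's topic counts, the word's topic row, and the corpus totals. A minimal standalone sketch of that bookkeeping, with hypothetical names and sizes:

// Standalone sketch of the per-token reassignment bookkeeping.
#include <cstdint>
#include <vector>

struct LdaCounts {
    std::vector<int32_t> doc_topic;     // per-document topic counts
    std::vector<int64_t> word_topic;    // one word's row of topic counts
    std::vector<int64_t> corpus_topic;  // corpus-wide topic totals
};

static void reassign(LdaCounts& c, int32_t old_z, int32_t new_z) {
    // All three structures move in lockstep so they stay mutually consistent.
    c.doc_topic[old_z]--;    c.doc_topic[new_z]++;
    c.word_topic[old_z]--;   c.word_topic[new_z]++;
    c.corpus_topic[old_z]--; c.corpus_topic[new_z]++;
}

int main() {
    LdaCounts c{{1, 2, 0}, {1, 3, 0}, {10, 20, 5}};
    reassign(c, /*old_z=*/1, /*new_z=*/2);   // move one token from topic 1 to 2
    return c.doc_topic[2] == 1 ? 0 : 1;
}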