// Inference-time sweep over every token of `doc`.
//
// Resets the document cursor to 0, then for each token position draws a
// topic via Sample2WordFirstInfer (passing the current topic for both of its
// trailing arguments). When the drawn topic differs from the stored one, the
// document's topic assignment and doc_topic_counter_ are updated. No other
// sampler state is written directly by this function.
//
// Returns the number of token positions visited.
int32_t LightDocSampler::OldProposalFreshSampleInfer(LDADocument *doc)
    {
        DocInit(doc);

        // Size is read once, after DocInit, and used as the loop bound.
        int token_count = doc->size();

        // `pos` aliases the cursor owned by the document, so every advance
        // below is visible through doc->get_cursor().
        int32_t &pos = doc->get_cursor();

        int32_t visited = 0;
        for (pos = 0; pos < token_count; ++pos)
        {
            ++visited;

            int32_t word = doc->Word(pos);
            int32_t current_topic = doc->Topic(pos);

            int drawn_topic = Sample2WordFirstInfer(doc, word, current_topic, current_topic);
            if (drawn_topic == current_topic)
            {
                continue;   // assignment unchanged: nothing to update
            }

            doc->SetTopic(pos, drawn_topic);
            doc_topic_counter_.inc(current_topic, -1);
            doc_topic_counter_.inc(drawn_topic, 1);
        }
        return visited;
    }
Example #2
0
	// Runs inference over the tokens of `doc` that belong to the given model slice.
	//
	// doc              - document to process; its cursor and topic assignments are updated.
	// word_topic_table - model slice; supplies SliceId() and LastWord().
	// summary_row      - forwarded unchanged to InferWordFirst.
	// alias_table      - forwarded unchanged to InferWordFirst.
	//
	// The document cursor persists between calls: it is rewound to 0 only when
	// the slice id is 0, otherwise processing resumes where the previous call
	// stopped. The loop stops at the first word id greater than the slice's last
	// word (NOTE(review): this presumably relies on tokens being ordered by word
	// id within a document -- confirm against LDADocument).
	void LightDocSampler::InferOneDoc(LDADocument* doc, ModelSlice& word_topic_table,
		petuum::ClientSummaryRow& summary_row, AliasSlice& alias_table) 
	{
		DocInit(doc);
		const int num_token = doc->size();

		int32_t slice_id = word_topic_table.SliceId();
		int32_t slice_last_word = word_topic_table.LastWord();

		int32_t& cursor = doc->get_cursor();
		if (slice_id == 0) cursor = 0;

		// Bound with `<` on the size read once above: unlike `!=`, this cannot
		// run past the end if the persisted cursor is already beyond the last
		// token, and it avoids re-querying the size on every iteration.
		for (; cursor < num_token; ++cursor) {
			int32_t word = doc->Word(cursor);
			if (word > slice_last_word)
				break;	// remaining tokens are left for a later slice

			int32_t old_topic = doc->Topic(cursor);
			int32_t new_topic = InferWordFirst(doc, word, old_topic, old_topic,
				word_topic_table, summary_row, alias_table);

			if (old_topic != new_topic) {
				doc_topic_counter_.inc(old_topic, -1);
				doc_topic_counter_.inc(new_topic, 1);
				doc->SetTopic(cursor, new_topic);
			}
		}
	}
Example #3
0
	// Samples new topics for the tokens of `doc` that belong to the given model slice.
	//
	// doc                  - document to process; its cursor and topic assignments are updated.
	// word_topic_table     - model slice; supplies SliceId() and LastWord().
	// summary_row          - forwarded unchanged to Sample2WordFirst.
	// alias_table          - forwarded unchanged to Sample2WordFirst.
	// word_topic_delta_vec - per-shard delta accumulators; the shard is chosen as
	//                        word % word_topic_delta_vec.size(), so the vector must
	//                        not be empty.
	// summary_delta        - receives a -1/+1 update for the old/new topic.
	//
	// The document cursor persists between calls: it is rewound to 0 only when
	// the slice id is 0. The loop stops at the first word id greater than the
	// slice's last word (NOTE(review): presumably tokens are ordered by word id
	// within a document -- confirm against LDADocument).
	//
	// Returns the number of tokens sampled by this call. The members
	// num_sampling_ and num_sampling_changed_ are incremented as running totals.
	int32_t LightDocSampler::SampleOneDoc(LDADocument *doc,
		ModelSlice& word_topic_table,
		petuum::ClientSummaryRow& summary_row,
		AliasSlice& alias_table,
		std::vector<std::unique_ptr<petuum::DeltaArray>>& word_topic_delta_vec,
		petuum::SummaryDelta& summary_delta)
	{
		DocInit(doc);
		const int num_token = doc->size();
		int32_t num_sampling = 0;
		int32_t slice_id = word_topic_table.SliceId();
		int32_t slice_last_word = word_topic_table.LastWord();
		int32_t& cursor = doc->get_cursor();
		if (slice_id == 0) cursor = 0;

		// Bound with `<` on the size read once above: unlike `!=`, this cannot
		// run past the end if the persisted cursor is already beyond the last
		// token, and it avoids re-querying the size on every iteration.
		for (; cursor < num_token; ++cursor) {

			int32_t word = doc->Word(cursor);

			if (word > slice_last_word)
				break;	// remaining tokens are left for a later slice

			++num_sampling;
			++num_sampling_;
			int32_t old_topic = doc->Topic(cursor);
			int32_t new_topic = Sample2WordFirst(doc, word, old_topic, old_topic,
				word_topic_table, summary_row, alias_table);
			if (old_topic != new_topic) {
				int32_t shard_id = word % word_topic_delta_vec.size();

				// Retract the old assignment from every count that tracks it...
				word_topic_delta_vec[shard_id]->Update(word, old_topic, -1);
				doc_topic_counter_.inc(old_topic, -1);
				summary_delta.Update(old_topic, -1);

				// ...then record the new one.
				word_topic_delta_vec[shard_id]->Update(word, new_topic, 1);
				doc_topic_counter_.inc(new_topic, 1);
				summary_delta.Update(new_topic, 1);

				doc->SetTopic(cursor, new_topic);
				++num_sampling_changed_;
			}
		}
		return num_sampling;
	}
    // Training-time sweep over every token of `doc`.
    //
    // Resets the document cursor to 0, then for each token position draws a
    // topic via Sample2WordFirst (passing the current topic for both of its
    // trailing arguments). When the drawn topic differs from the stored one:
    //   * a -1 record for the old topic and a +1 record for the new topic are
    //     appended to word_topic_delta_[word % num_threads_];
    //   * delta_summary_row_ is decremented/incremented for the old/new topic;
    //   * the document's assignment and doc_topic_counter_ are updated.
    //
    // Returns the number of token positions visited.
    int32_t LightDocSampler::OldProposalFreshSample(LDADocument *doc)
    {
        DocInit(doc);

        // Size is read once, after DocInit, and used as the loop bound.
        int token_count = doc->size();

        // `pos` aliases the cursor owned by the document.
        int32_t &pos = doc->get_cursor();

        int32_t visited = 0;
        for (pos = 0; pos < token_count; ++pos)
        {
            ++visited;

            int32_t word = doc->Word(pos);
            int32_t current_topic = doc->Topic(pos);

            int drawn_topic = Sample2WordFirst(doc, word, current_topic, current_topic);
            if (drawn_topic == current_topic)
            {
                continue;   // assignment unchanged: nothing to record
            }

            int32_t shard = word % num_threads_;

            // One record object is reused for both entries, so any field not
            // assigned here carries the same value in the -1 and +1 records.
            word_topic_delta change;
            change.word = word;
            change.topic = current_topic;
            change.delta = -1;
            word_topic_delta_[shard].push_back(change);

            change.topic = drawn_topic;
            change.delta = +1;
            word_topic_delta_[shard].push_back(change);

            --delta_summary_row_[current_topic];
            ++delta_summary_row_[drawn_topic];

            doc->SetTopic(pos, drawn_topic);
            doc_topic_counter_.inc(current_topic, -1);
            doc_topic_counter_.inc(drawn_topic, 1);
        }
        return visited;
    }