Exemplo n.º 1
0
// Sets up the initial topic assignments for the sampler state.
//
// After the response and length bookkeeping are initialized and the base
// LdawnState assignment has run (honoring random_init), up to
// FLAGS_num_seed_docs training documents per (language, topic) pair are
// re-assigned wholesale to a single topic chosen from the rank of the
// document's training rating.  Topic zero is never used as a seed topic.
void MlSldaState::InitializeAssignments(bool random_init) {
  InitializeResponse();
  InitializeLength();

  LdawnState::InitializeAssignments(random_init);

  // Seeding is optional; nothing more to do when it is switched off.
  if (FLAGS_num_seed_docs <= 0) {
    return;
  }

  const gsl_vector* y =
    static_cast<lib_corpora::ReviewCorpus*>(corpus_.get())->train_ratings();

  // sorted: permutation that orders the ratings ascending.
  // rank:   its inverse, i.e. rank->data[i] is the position of rating i in
  //         the sorted order.
  boost::shared_ptr<gsl_permutation> sorted(gsl_permutation_alloc(y->size),
                                            gsl_permutation_free);
  boost::shared_ptr<gsl_permutation> rank(gsl_permutation_alloc(y->size),
                                          gsl_permutation_free);
  gsl_sort_vector_index(sorted.get(), y);
  gsl_permutation_inverse(rank.get(), sorted.get());

  // Per-language, per-topic count of documents already used as seeds
  // (value-initialized to zero).
  std::vector< std::vector<int> > num_seeds_used(
      corpus_->num_languages(), std::vector<int>(num_topics_));

  // The extra 1.0 is padding so that the scaled rank below never reaches
  // the number of topics.
  const double num_train = corpus_->num_train() + 1.0;
  const int total_docs = corpus_->num_docs();
  int train_seen = 0;

  for (int doc_id = 0; doc_id < total_docs; ++doc_id) {
    MlSeqDoc* doc = corpus_->seq_doc(doc_id);
    const int lang = doc->language();
    if (corpus_->doc(doc_id)->is_test()) {
      continue;
    }

    // Index of this document among the training documents seen so far; the
    // counter advances for every training document, seeded or not.
    const int train_index = train_seen++;

    // Scale the rating rank onto topics 1..num_topics_-1.  Topic zero is
    // skipped on purpose so it can be stopwordy.
    const int val = static_cast<int>(floor((num_topics_ - 1) *
                                           rank->data[train_index] /
                                           num_train)) + 1;

    // Respect the per-(language, topic) seed limit; too many seed documents
    // lead to an overfit initial state.
    if (num_seeds_used[lang][val] >= FLAGS_num_seed_docs) {
      continue;
    }

    cout << "Initializing doc " << lang << " " << doc_id << " to " << val <<
      " score=" << truth_[doc_id] << endl;

    for (int pos = 0;
         pos < static_cast<int>(topic_assignments_[doc_id].size());
         ++pos) {
      const int term = (*doc)[pos];
      const topicmod_projects_ldawn::WordPaths word =
        wordnet_->word(lang, term);
      const int num_paths = word.size();

      if (num_paths > 0) {
        // The word has wordnet paths: pick one of them uniformly via rand().
        ChangePath(doc_id, pos, val, rand() % num_paths);
      } else if (use_aux_topics()) {
        // No path available: only the topic is set, and only when auxiliary
        // topics are in use.
        ChangeTopic(doc_id, pos, val);
      }
    }
    ++num_seeds_used[lang][val];
  }
}
Exemplo n.º 2
0
Real MultipleCurve3<Real>::GetLength (Real t0, Real t1) const
{
    // Returns the length accumulated between the times t0 and t1.  Both
    // times must lie in [mTMin,mTMax] and satisfy t0 <= t1.
    assertion(mTMin <= t0 && t0 <= mTMax, "Invalid input\n");
    assertion(mTMin <= t1 && t1 <= mTMax, "Invalid input\n");
    assertion(t0 <= t1, "Invalid input\n");

    // The per-segment length table is built on first use.
    if (!mLengths)
    {
        InitializeLength();
    }

    // Look up the key and the associated time offset for each endpoint.
    int firstKey, lastKey;
    Real firstOffset, lastOffset;
    GetKeyInfo(t0, firstKey, firstOffset);
    GetKeyInfo(t1, lastKey, lastOffset);

    if (firstKey >= lastKey)
    {
        // Both endpoints map to the same key, so a single partial-segment
        // query covers the whole interval.
        return GetLengthKey(firstKey, firstOffset, lastOffset);
    }

    // Sum the lengths of the segments lying strictly between the two keys.
    Real total = (Real)0;
    for (int segment = firstKey + 1; segment < lastKey; ++segment)
    {
        total += mLengths[segment];
    }

    // Remainder of the first segment, from the start offset up to the full
    // time span of that segment.
    total += GetLengthKey(firstKey, firstOffset,
        mTimes[firstKey + 1] - mTimes[firstKey]);

    // Leading portion of the last segment, from zero up to the end offset.
    total += GetLengthKey(lastKey, (Real)0, lastOffset);

    return total;
}
Exemplo n.º 3
0
Real MultipleCurve3<Real>::GetTime (Real length, int iterations,
    Real tolerance) const
{
    // Inverts the length function: returns a time at which the accumulated
    // length equals 'length', using at most 'iterations' refinement steps
    // and accepting a result whose length error is within 'tolerance'.

    // The length tables are built on first use.
    if (!mLengths)
    {
        InitializeLength();
    }

    // Lengths at or beyond either end of the curve clamp to the end times.
    if (length <= (Real)0)
    {
        return mTMin;
    }
    if (length >= mAccumLengths[mNumSegments - 1])
    {
        return mTMax;
    }

    // Find the first segment whose accumulated length exceeds 'length'.
    int key = 0;
    while (key < mNumSegments && !(length < mAccumLengths[key]))
    {
        ++key;
    }
    if (key >= mNumSegments)
    {
        return mTimes[mNumSegments];
    }

    // Express the target relative to the start of the selected segment:
    // 'target' is the length to cover inside it, 'segLength' is the total
    // length of the segment.
    Real target = length;
    Real segLength = mAccumLengths[key];
    if (key > 0)
    {
        target -= mAccumLengths[key - 1];
        segLength -= mAccumLengths[key - 1];
    }

    // With L(t) the length function, L'(t) = |x'(t)| >= 0, so L is
    // nondecreasing (x'(t) is assumed to vanish only at isolated points).
    // Newton's method is guaranteed to converge when L is convex, but L"(t)
    // may be negative and the Newton iterate can then leave the segment.
    // To avoid that, every iterate is checked against a root-bounding
    // interval and replaced by the interval midpoint when it falls outside.

    // Initial Newton guess: the time offset proportional to the fraction of
    // the segment length that must be covered.
    const Real segTime = mTimes[key + 1] - mTimes[key];
    Real guess = segTime*target/segLength;

    // Root-bounding interval used for the bisection fallback.
    Real lo = (Real)0;
    Real hi = segTime;

    for (int step = 0; step < iterations; ++step)
    {
        const Real error = GetLengthKey(key, (Real)0, guess) - target;
        if (Math<Real>::FAbs(error) <= tolerance)
        {
            // The length at mTimes[key]+guess is close enough to 'length'.
            return mTimes[key] + guess;
        }

        // Newton candidate computed from the current guess.
        const Real candidate = guess - error/GetSpeedKey(key, guess);

        // A positive error means the guess is past the root, so it becomes
        // the new upper bound; otherwise it becomes the new lower bound.
        const bool pastRoot = (error > (Real)0);
        if (pastRoot)
        {
            hi = guess;
        }
        else
        {
            lo = guess;
        }

        // Only the bound on the far side of the Newton step is tested.  The
        // tangent line has positive slope, so the candidate is already on
        // the correct side of the bound that was just updated.
        const bool outside = pastRoot ? (candidate <= lo) : (candidate >= hi);
        guess = outside ? ((Real)0.5)*(hi + lo) : candidate;
    }

    // No root satisfied the tolerance within the allowed iterations.
    // Increasing iterations, tolerance, or integration accuracy may help,
    // but the time values are likely oscillating because of the limited
    // precision of 32-bit floats.  The last computed time is safe to use.
    return mTimes[key] + guess;
}