예제 #1
0
파일: srw.cpp 프로젝트: dmitrime/srw
    /**
     * Objective value for the current weight vector `wvec`, plus gradient accumulation.
     *
     * Objective = dot(wvec, wvec)
     *           + sum over subgraphs s of
     *               sum_{d in sg->positive} sum_{l in sg->negative} hloss((p_l - p_d) / sum_p)
     * where p_x is power->get_pagerank(x) after pers_pagerank() has run, and
     * sum_p / sum_d are presumably the normalizer and its per-weight derivative as
     * produced by candidate_sums() (TODO confirm against candidate_sums).
     *
     * @param drv  Gradient output. Values are ADDED (+=) to the existing contents;
     *             nothing here resets it, so the caller must zero it first if a fresh
     *             gradient is wanted. The first loop walks drv.size() while the later
     *             loop walks wvec.size(), so drv is assumed to have exactly
     *             wvec.size() entries -- NOTE(review): not checked here.
     * @return     dot(wvec, wvec) plus the summed pairwise loss over all subgraphs.
     *
     * Side effects: for every subgraph s, pushes wvec into parameters[s], recomputes
     * the subgraph scores and parameter derivatives, and reruns powers[s]'s
     * pers_pagerank() and derivatives().
     */
    double Fw(vector<double> &drv)
    {
        // Regularization term ||w||^2; its gradient 2w is added to drv below.
        double norm = dot_product(wvec, wvec);

        for (unsigned w = 0; w < drv.size(); w++)
            drv[w] += 2*wvec[w];

        double sum = 0.0;
        for (unsigned s = 0; s < subgraphs.size(); s++)
        {
            Subgraph *sg = subgraphs[s];
            Params *params = parameters[s];
            PowerMethod *power = powers[s];

            // Refresh this subgraph's model for the current weights before scoring it.
            params->set_wvec(wvec);

            sg->recompute_scores(*params);
            params->recalculate_derivs();

            power->pers_pagerank();
            power->derivatives();

            double sum_p = 0.0;
            vector<double> sum_d(wvec.size(), 0.0);
            candidate_sums(s, sum_p, sum_d);

            // NOTE(review): every term below divides by sum_p; nothing here guards
            // against sum_p == 0.
            // Loop-invariant denominator of the quotient-rule derivative of p_x / sum_p.
            const double sum_p2 = sum_p*sum_p;

            double loss = 0.0;
            for (unsigned d = 0; d < sg->positive.size(); d++)
            {
                unsigned pos = sg->positive[d];
                double p_d = power->get_pagerank(pos);
                for (unsigned l = 0; l < sg->negative.size(); l++)
                {
                    unsigned neg = sg->negative[l];
                    double p_l = power->get_pagerank(neg);

                    double diff = p_l/sum_p - p_d/sum_p;
                    loss += hloss(diff);

                    // dhloss(diff) does not depend on w: evaluate it once per
                    // (positive, negative) pair rather than once per weight.
                    const double dh = dhloss(diff);
                    for (unsigned w = 0; w < wvec.size(); w++)
                        drv[w] += dh * (
                                    (power->get_derivative(neg, w)*sum_p - p_l*sum_d[w])/sum_p2
                                    -
                                    (power->get_derivative(pos, w)*sum_p - p_d*sum_d[w])/sum_p2
                                  );
                }
            }
            sum += loss;
        }
        return norm + sum;
    }
예제 #2
0
파일: srw.cpp 프로젝트: dmitrime/srw
        /**
         * Scores and evaluates every user in `users` with the current weight vector `wvec`.
         *
         * For each user, in order: builds a Subgraph for the user at TIMEPOINT, computes
         * its scores from a Params built from `wvec`, runs pers_pagerank() and passes the
         * PowerMethod to evaluate(). report(outfile) is called after every
         * `report_every` processed users and once more after the loop.
         *
         * @param graph    graph the per-user subgraphs are built from
         * @param profile  not referenced in this body
         * @param users    ids of the users to evaluate, processed in index order
         * @param outfile  path handed to report()
         */
        void run(Graph *graph, Users_data *profile, vector<user_id>& users, const char* outfile)
        {
            // Progress-report interval, in processed users.
            const unsigned report_every = 5000;

            for (unsigned u = 0; u < users.size(); u++)
            {
                user_id id = users[u];

                // Automatic storage instead of new/delete. The objects are destroyed at the
                // end of each iteration in reverse declaration order (power, prm, sub), the
                // same order the explicit deletes used. They are also destroyed if
                // evaluate() or report() throws, a case where the heap version leaked all
                // three objects.
                // NOTE(review): assumes sizeof(Subgraph/Params/PowerMethod) is modest
                // enough for the stack -- confirm they keep bulk data in heap containers.
                Subgraph sub(graph, id, TIMEPOINT);
                Params prm(wvec, sub.subgraph, sub.mutual);
                sub.recompute_scores(prm);

                PowerMethod power(&sub, &prm, wvec.size());
                power.pers_pagerank();
                evaluate(&power);

                // Output after every `report_every` users. u is 0-based, so the n-th
                // processed user has index n - 1. The previous test `u % 5000 == 0`
                // fired after the 1st, 5001st, ... user instead.
                if ((u + 1) % report_every == 0)
                    report(outfile);
            }
            report(outfile);
        }