Exemple #1
0
/*
 * Scales X in place by diag(b) on the left and inv(diag(std)) on the right,
 * where avg/std are the n-entry vectors filled by dmat_colavg/dmat_colstd.
 * Entries of std below 1.0e-20 are replaced by 1 before scaling.
 *
 * Two cases, selected by X->nz:
 *   nz <  0 : avg[j] is first subtracted from every X->val[i*n+j]
 *             (i < m, j < n); no ac/ar vectors are produced (both NULL).
 *   nz >= 0 : X is not centered; instead two vectors are produced,
 *             ac (m entries: copy of b, or all ones when b is NULL) and
 *             ar (n entries, computed by dmat_elemdivi from avg and std;
 *             per the note below, avg^T*inv(diag(std))).
 *             NOTE(review): presumably nz >= 0 marks sparse storage and
 *             ac/ar let callers apply the mean shift implicitly - confirm.
 *
 * Outputs (average, stddev, acol, arow): for each non-NULL pointer the
 * corresponding malloc'ed buffer (possibly NULL for acol/arow) is stored
 * through it; for each NULL pointer the buffer is freed here instead.
 * Allocation results are not checked.
 */
void standardize_data(dmatrix *X, const double *b, double **average, 
                      double **stddev, double **acol, double **arow)
{
    const int cols = X->n;
    const int rows = X->m;
    const int nnz  = X->nz;

    double *col_mean  = malloc(cols*sizeof(double));
    double *col_sd    = malloc(cols*sizeof(double));
    double *row_shift = NULL;   /* becomes *arow */
    double *col_shift = NULL;   /* becomes *acol */
    int k;

    dmat_colavg(X, col_mean);
    dmat_colstd(X, col_mean, col_sd);

    /* avoid dividing by a (numerically) zero deviation */
    k = 0;
    while (k < cols)
    {
        if (col_sd[k] < 1.0e-20)
        {
            col_sd[k] = 1;
        }
        k++;
    }

    if (nnz < 0)
    {
        /* center explicitly: remove avg[c] from each entry (r, c) */
        int r, c;

        for (r = 0; r < rows; r++)
        {
            const int base = r*cols;

            for (c = 0; c < cols; c++)
            {
                X->val[base+c] -= col_mean[c];
            }
        }

        /* X = diag(b)*X*inv(diag(std)) */
        dmat_diagscale(X, b, FALSE, col_sd, TRUE);
    }
    else
    {
        row_shift = malloc(cols*sizeof(double));
        col_shift = malloc(rows*sizeof(double));

        /* X := diag(b)*X*inv(diag(std)) */
        dmat_diagscale(X, b, FALSE, col_sd, TRUE);

        /* ac = diag(b)*1 */
        if (b == NULL)
        {
            dmat_vset(rows, 1.0, col_shift);
        }
        else
        {
            dmat_vcopy(rows, b, col_shift);
        }

        /* ar = avg^T*inv(diag(std)) */
        dmat_elemdivi(cols, col_mean, col_sd, row_shift);
    }

    /* hand each buffer to the caller, or release it when not requested */
    if (average == NULL) free(col_mean);
    else                 *average = col_mean;

    if (stddev == NULL)  free(col_sd);
    else                 *stddev = col_sd;

    if (acol == NULL)    free(col_shift);
    else                 *acol = col_shift;

    if (arow == NULL)    free(row_shift);
    else                 *arow = row_shift;
}
/*
 * Trains an l1-regularized logistic regression model.
 *
 * Runs at most MAX_NT_ITER passes of a loop that (1) sets the intercept,
 * (2) builds a dual feasible point and evaluates the duality gap,
 * (3) updates the parameter t, (4) computes a search direction (pcg when
 * matX1->nz >= 0, otherwise a direct Cholesky-based method), and
 * (5) performs a backtracking line search.
 *
 * Parameters:
 *   X             data matrix; it is duplicated into a working matrix
 *                 (matX1) and all later matrix operations here use matX1.
 *   b             vector used to scale the rows of the working matrix
 *                 (diag(b)*X) and in gv = b'*g.
 *   lambda        weight of the l1 term in the primal objective
 *                 (pobj = loss/m + lambda*||w||_1).
 *   to            options; the fields read here are sflag, cflag,
 *                 verbose_level, tolerance and ktolerance.
 *   initial_x     NULL for a cold start. Otherwise 2n+1 values are copied
 *                 from it into x at entry, and the final x is copied back
 *                 into it at exit (warm-start state).
 *   initial_t     read as the starting t and overwritten with the final t,
 *                 but only when initial_x is non-NULL; it is dereferenced
 *                 unconditionally in that case, so it must then be non-NULL.
 *   sol           optional output of n+1 values: sol[0] = x[0], and
 *                 sol[i+1] = x[i+1] when |gw[i]| > to.ktolerance*lambda,
 *                 else 0. When to.sflag is TRUE and to.cflag is FALSE the
 *                 coefficients are divided by std_x and sol[0] is shifted
 *                 by -avg_x'*sol[1..n].
 *   total_ntiter  optional output: value of the loop counter at exit.
 *   total_pcgiter optional output: sum of pcgstat.iter over all pcg calls.
 *
 * Returns:
 *   STATUS_SOLUTION_FOUND        when gap < to.tolerance,
 *   STATUS_MAX_LS_ITER_EXCEEDED  when the line search returns a negative
 *                                step,
 *   STATUS_MAX_NT_ITER_EXCEEDED  when the loop runs MAX_NT_ITER times
 *                                without meeting the tolerance.
 *
 * NOTE(review): matX1, matX2, ac, ar, avg_x and std_x are never freed
 * directly in this function; presumably free_problem_data() releases them
 * (they are all handed to set_problem_data) - confirm.
 * When INTERNAL_PLOT is enabled, internal_plot is written out but no call
 * freeing it appears in this function.
 */
int l1_logreg_train(dmatrix *X, double *b, double lambda, train_opts to,
                    double *initial_x, double *initial_t, double *sol,
                    int *total_ntiter, int *total_pcgiter)
{
    /* problem data */
    problem_data_t  prob;
    variables_t     vars;

    dmatrix *matX1;     /* matX1 = diag(b)*X_std */
    dmatrix *matX2;     /* matX2 = X_std.^2 (only for pcg) */
    double *ac, *ar;
    double *avg_x, *std_x;

    int m, n, ntiter, pcgiter, status;
    double pobj, dobj, gap;
    double t, s, maxAnu;

    double *g,  *h,  *z,  *expz, *expmz;
    double *x,  *v,  *w,  *u;
    double *dx, *dv, *dw, *du;
    double *gv, *gw, *gu, *gx;
    double *d1, *d2, *Aw;

    /* pcg variables */
    pcg_status_t pcgstat;
    adata_t adata;
    mdata_t mdata;
    double *precond;

    /* temporary variables */
    double *tm1, *tn1, *tn2, *tn3, *tn4, *tx1;

    /* temporary variables for dense case (cholesky) */
    dmatrix *B;     /* m x n (or presumably n x m; the original note
                       repeated "m x n" - confirm) */
    dmatrix *BB;    /* n x n (or m x m) */

    char format_buf[PRINT_BUF_SIZE];

#if INTERNAL_PLOT
    dmatrix *internal_plot;
    dmat_new_dense(&internal_plot, 3, MAX_NT_ITER);
    memset(internal_plot->val,0,sizeof(double)*3*MAX_NT_ITER);
    /* row 1: cum_nt_iter, row 2: cum_pcg_iter, row 3: duality gap */
#endif

    /* stays NULL unless to.verbose_level >= 2 (set by init_progress below);
       it is only called under that same condition */
    p2func_progress print_progress = NULL;

    /*
     *  INITIALIZATION
     */
    s       =  1.0;         /* step size; 1.0 until the first line search */
    pobj    =  DBL_MAX;
    dobj    = -DBL_MAX;     /* best dual objective seen so far */
    pcgiter =  0;
    matX1   =  NULL;
    matX2   =  NULL;

    init_pcg_status(&pcgstat);

    /* build the working matrix matX1 from X */
    dmat_duplicate(X, &matX1);
    dmat_copy(X, matX1);

    m = matX1->m;
    n = matX1->n;

    if (to.sflag == TRUE)
    {
        /* standardize_data not only standardizes the data,
           but also multiplies diag(b). */
        standardize_data(matX1, b, &avg_x, &std_x, &ac, &ar);
    }
    else
    {
        /* matX1 = diag(b)*X */
        dmat_diagscale(matX1, b, FALSE, NULL, TRUE);
        avg_x = std_x = ac = ar = NULL;
    }

    if (matX1->nz >= 0)                /* only for pcg */
    {
        dmat_elemAA(matX1, &matX2);
    }
    else
    {
        matX2 = NULL;
    }

    set_problem_data(&prob, matX1, matX2, ac, ar, b, lambda, avg_x, std_x);

    create_variables(&vars, m, n);
    get_variables(&vars, &x, &v, &w, &u, &dx, &dv, &dw, &du, &gx, &gv,
                  &gw, &gu, &g, &h, &z, &expz, &expmz, &d1, &d2, &Aw);

    allocate_temporaries(m, n, (matX1->nz >= 0),
                         &tm1, &tn1, &tn2, &tn3, &tn4, &tx1, &precond, &B, &BB);

    if (initial_x == NULL)
    {
        /* cold start: v = 0, w = 0, u = 1, and a default t */
        dmat_vset(1, 0.0, v);
        dmat_vset(n, 0.0, w);
        dmat_vset(n, 1.0, u);
        dmat_vset(n+n+1, 0, dx);
        t = min(max(1.0, 1.0 / lambda), 2.0 * n / ABSTOL);
    }
    else
    {
        /* warm start: take x (2n+1 values) and t from the caller */
        dmat_vcopy(n+n+1, initial_x, x);
        dmat_vset(n+n+1, 0, dx);
        t = *initial_t;
    }

    set_adata(&adata, matX1, ac, ar, b, h, d1, d2);
    set_mdata(&mdata, m, n, precond);

    /* select printing function and format according to
           verbose level and method type (pcg/direct) */

    if (to.verbose_level>=2) init_progress((matX1->nz >= 0), to.verbose_level,
                              format_buf, &print_progress);

    /*** MAIN LOOP ************************************************************/

    for (ntiter = 0; ntiter < MAX_NT_ITER; ntiter++)
    {
        /*
         *  Sets v as the optimal value of the intercept.
         */
        dmat_yAmpqx(matX1, ac, ar, w, Aw);
        optimize_intercept(v, z, expz, expmz, tm1, b, Aw, m);

        /*
         *  Constructs dual feasible point nu.
         */
        fprimes(m, expz, expmz, g, h);

        /* partially computes the gradient of phi.
           the rest part of the gradient will be completed while computing 
           the search direction. */

        gv[0] = dmat_dot(m, b, g);              /* gv = b'*g */
        dmat_yAmpqTx(matX1, ac, ar, g, gw);     /* gw = A'*g */

        dmat_waxpby(m, -1, g, 0, NULL, tm1);    /* nu = -g   */
        maxAnu = dmat_norminf(n, gw);           /* max(A'*nu) */

        /* shrink nu by lambda/maxAnu when the inf-norm exceeds lambda */
        if (maxAnu > lambda)
            dmat_waxpby(m, lambda / maxAnu, tm1, 0.0, NULL, tm1);

        /*
         *  Evaluates duality gap.
         */
        pobj = logistic_loss2(m,z,expz,expmz)/m + lambda*dmat_norm1(n,w);
        /* keep the largest dual objective found over all iterations */
        dobj = max(nentropy(m, tm1) / m, dobj);
        gap  = pobj - dobj;

#if INTERNAL_PLOT
        internal_plot->val[0*MAX_NT_ITER+ntiter] = (double)ntiter;
        internal_plot->val[1*MAX_NT_ITER+ntiter] = (double)pcgiter;
        internal_plot->val[2*MAX_NT_ITER+ntiter] = gap;
#endif
        if (to.verbose_level>=2)
        {
            (*print_progress)(format_buf, ntiter, gap, pobj, dobj, s, t,
                              pcgstat.flag, pcgstat.relres, pcgstat.iter);
        }

        /*
         *  Quits if gap < tolerance.
         */
        if (gap < to.tolerance ) /***********************************************/
        {
            if (sol != NULL)
            {
                /* trim solution */
                int i;
                double lambda_threshold;

                lambda_threshold = to.ktolerance*lambda;
                sol[0] = x[0];

                /* zero every coefficient whose |gw[i]| does not exceed
                   the threshold */
                for (i = 0; i < n; i++)
                {
                    sol[i+1] = (fabs(gw[i])>lambda_threshold)? x[i+1] : 0.0;
                }
                /* if standardized, sol = coeff/std */
                if (to.sflag == TRUE && to.cflag == FALSE)
                {
                    dmat_elemdivi(n, sol+1, std_x, sol+1);
                    sol[0] -= dmat_dot(n, avg_x, sol+1);
                }
            }

            /* write the warm-start state back to the caller */
            if (initial_x != NULL)
            {
                dmat_vcopy(n+n+1, x, initial_x);
                *initial_t = t;
            }

            if (total_pcgiter) *total_pcgiter = pcgiter;
            if (total_ntiter ) *total_ntiter  = ntiter;

            /* free memory */
            free_variables(&vars);
            free_temporaries(tm1, tn1, tn2, tn3, tn4, tx1, precond, B, BB);
            free_problem_data(&prob);

#if INTERNAL_PLOT
            write_mm_matrix("internal_plot",internal_plot,"",TYPE_G);
#endif
            return STATUS_SOLUTION_FOUND;

        } /********************************************************************/

        /*
         *  Updates t
         */
        /* s is the step from the previous line search (1.0 on the first
           pass): t is raised (never lowered) after a step >= 0.5, grown by
           10% after a step < 1e-5, and left unchanged otherwise */
        if (s >= 0.5)
        {
            t = max(min(2.0 * n * MU / gap, MU * t), t);
        }
        else if (s < 1e-5)
        {
            t = 1.1*t;
        }

        /*
         *  Computes search direction.
         */
        if (matX1->nz >= 0)
        {
            /* pcg */
            compute_searchdir_pcg(&prob, &vars, t, s, gap, &pcgstat, &adata, 
                                  &mdata, precond, tm1, tn1, tx1);
            pcgiter += pcgstat.iter;
        }
        else
        {
            /* direct */
            if (n > m)
            {
                /* direct method for n > m, SMW */
                compute_searchdir_chol_fat(&prob, &vars, t, B, BB,
                                           tm1, tn1, tn2, tn3, tn4);
            }
            else
            {
                /* direct method for n <= m */
                compute_searchdir_chol_thin(&prob, &vars, t, B, BB,
                                            tm1, tn1, tn2);
            }
        }

        /*
         *  Backtracking linesearch & update x = (v,w,u) and z.
         */
        s = backtracking_linesearch(&prob, &vars, t, tm1, tx1);
        if (s < 0) break; /* BLS error */
    }
    /*** END OF MAIN LOOP *****************************************************/

    /* Reaching this point means the tolerance was never met: either the
       line search failed (s < 0) or the loop ran MAX_NT_ITER times. The
       current iterate is still reported through sol/initial_x below. */

    /*  Abnormal termination */
    if (s < 0)
    {
        status = STATUS_MAX_LS_ITER_EXCEEDED;
    }
    else /* if (ntiter == MAX_NT_ITER) */
    {
        status = STATUS_MAX_NT_ITER_EXCEEDED;
    }

    if (sol != NULL)
    {
        /* trim solution */
        int i;
        double lambda_threshold;

        lambda_threshold = to.ktolerance*lambda;
        sol[0] = x[0];

        /* zero every coefficient whose |gw[i]| does not exceed the
           threshold */
        for (i = 0; i < n; i++)
        {
            sol[i+1] = (fabs(gw[i])>lambda_threshold)? x[i+1] : 0.0;
        }
        /* if standardized, sol = coeff/std */
        if (to.sflag == TRUE && to.cflag == FALSE)
        {
            dmat_elemdivi(n, sol+1, std_x, sol+1);
            sol[0] -= dmat_dot(n, avg_x, sol+1);
        }
    }

    /* write the warm-start state back to the caller */
    if (initial_x != NULL)
    {
        dmat_vcopy(n+n+1, x, initial_x);
        *initial_t = t;
    }
    if (total_pcgiter) *total_pcgiter = pcgiter;
    if (total_ntiter ) *total_ntiter  = ntiter;

    /* free memory */
    free_variables(&vars);
    free_temporaries(tm1, tn1, tn2, tn3, tn4, tx1, precond, B, BB);
    free_problem_data(&prob);

#if INTERNAL_PLOT
    write_mm_matrix("internal_plot",internal_plot,"",TYPE_G);
#endif
    return status;
}