Ejemplo n.º 1
0
 // Forward pass: writes into energy.x half the squared distance between
 // in1.x and the row of `targets` selected by the value stored in label.x.
 // Side effect: the selected target row is copied into the member in2.x.
 void euclidean_module<T1,T2,Tstate1,Tstate2>::
 fprop(Tstate1 &in1, Tstate2 &label, Tstate1 &energy) {
   // pick the target row (along dimension 0) indexed by the label value
   idx<T1> wanted = targets.select(0, label.x.get());
   // keep a copy of that row in in2.x; the distance below is taken against it
   idx_copy(wanted, in2.x);

   // energy.x <- squared distance between in1.x and the copied target
   idx_sqrdist(in1.x, in2.x, energy.x);
   // energy.x <- 0.5 * energy.x (scaled in place)
   idx_dotc(energy.x, 0.5, energy.x);
 }
Ejemplo n.º 2
0
void class_answer<T, Tds1, Tds2>::fprop1(idx<T> &in, idx<T> &out)
{
    // resize out if necessary
    idxdim d(in);

    d.setdim(0, 2); // 2 outputs per pixel: class,confidence
    idx<T> outx = out;
    idx<T> inx = in;
    if (resize_output)
    {
        if (d != out.get_idxdim())
        {
            out.resize(d);
            outx = out;
        }
    }
    else   // if not resizing, narrow to the number of targets
    {
        if (outx.dim(0) != targets.dim(0))
            outx = outx.narrow(0, targets.dim(0), 0);
    }
    // apply tanh if required
    if (apply_tanh)
    {
        mtanh.fprop1(in, tmp);
        inx = tmp;
    }
    // loop on features (dimension 0) to set class and confidence
    int classid;
    T conf, max2 = 0;
    idx_1loop2(ii, inx, T, oo, outx, T, {
                   if (binary_target)
                   {
                       T t0 = targets.gget(0);
                       T t1 = targets.gget(1);
                       T a = ii.gget();
                       if (std::fabs((double)a - t0) < std::fabs((double)a - t1))
                       {
                           oo.set((T)0, 0); // class 0
                           oo.set((T)(2 - std::fabs((double)a - t0)) / 2, 1); // conf
                       }
                       else
                       {
                           oo.set((T)1, 0); // class 1
                           oo.set((T)(2 - std::fabs((double)a - t1)) / 2, 1); // conf
                       }
                   }
                   else if (single_output >= 0)
                   {
                       oo.set((T)single_output, 0); // all answers are the same class
                       oo.set((T)((ii.get(single_output) - target_min) / target_range), 1);
                   }
                   else // 1-of-n target
                   { // set class answer
                       if (force_class >= 0) classid = force_class;
                       else classid = idx_indexmax(ii);
                       oo.set((T)classid, 0);
                       // set confidence
                       intg p;
                       bool ini = false;
                       switch (conf_type)
                       {
                       case confidence_sqrdist: // squared distance to target
                           target = targets.select(0, classid);
                           conf = (T)(1.0 - ((idx_sqrdist(target, ii) - conf_shift)
                                             / conf_ratio));
                           oo.set(conf, 1);
                           break;
                       case confidence_single: // simply return class' out (normalized)
                           conf = (T)((ii.get(classid) - conf_shift) / conf_ratio);
                           oo.set(conf, 1);
                           break;
                       case confidence_max: // distance with 2nd max answer
                           conf = std::max(target_min, std::min(target_max, ii.get(classid)));
                           for (p = 0; p < ii.dim(0); ++p)
                           {
                               if (p != classid)
                               {
                                   if (!ini)
                                   {
                                       max2 = ii.get(p);
                                       ini = true;
                                   }
                                   else
                                   {
                                       if (ii.get(p) > max2)
                                           max2 = ii.get(p);
                                   }
                               }
                           }

                           max2 = std::max(target_min, std::min(target_max, max2));
                           oo.set((T)(((conf - max2) - conf_shift) / conf_ratio), 1);
                           break;
                       default:
                           eblerror("confidence type " << conf_type << " undefined");
                       }
                   }
               });