/**
 * \brief Map step of the error reduction: converts a single edge into a
 * partial error_aggregator.
 *
 * An edge whose role is TRAIN contributes its squared error (as computed by
 * extract_l2_error) to train_error and sets ntrain to 1. A VALIDATE edge does
 * the same for validation_error / nvalidation. Edges with any other role
 * produce the default-constructed aggregator unchanged (its initial field
 * values are defined by error_aggregator's constructor, not visible here).
 *
 * \param context not used by this function; presumably required by the
 *                caller's map signature — confirm against the call site.
 * \param edge    the edge whose prediction error is being measured.
 * \return the per-edge partial aggregate.
 */
static error_aggregator map(icontext_type& context,
                            const graph_type::edge_type& edge) {
  error_aggregator partial;
  switch (edge.data().role) {
    case edge_data::TRAIN:
      partial.train_error = extract_l2_error(edge);
      partial.ntrain = 1;
      break;
    case edge_data::VALIDATE:
      partial.validation_error = extract_l2_error(edge);
      partial.nvalidation = 1;
      break;
    default:
      // Edges of any other role do not contribute to either error term.
      break;
  }
  return partial;
}
/**
 * \brief Records, for one neighbor of \p center, either the observed rating
 * or the model's predicted rating, keyed by the neighbor's vertex id.
 *
 * For a TRAIN edge, edge.data().obs is written into first.data[other.id()].
 * For an edge of any other role, the dot product of center's and other's
 * factor vectors is written into second.data[other.id()].
 *
 * \param center the vertex whose neighborhood is being visited.
 * \param edge   the edge joining center and other (only read here).
 * \param other  the neighboring vertex; its id() is the map key.
 * \return a map_join_pair with one entry written into either first.data or
 *         second.data.
 */
map_join_pair collect_map (const graph_type::vertex_type& center,
                           graph_type::edge_type& edge,
                           const graph_type::vertex_type& other) {
  map_join_pair result;
  if (edge.data().role != edge_data::TRAIN) {
    // Not a training edge: store the model's prediction for this pair.
    const double predicted = center.data().factor.dot(other.data().factor);
    result.second.data[other.id()] = predicted;
    return result;
  }
  // Training edge: keep the observed rating.
  result.first.data[other.id()] = edge.data().obs;
  return result;
}
/**
 * \brief Squared (L2) prediction error on a single edge.
 *
 * The model's prediction is the dot product of the source and target vertex
 * factor vectors, converted to double; the result is (obs - prediction)^2.
 *
 * \param edge the edge whose observed value (data().obs) is compared with
 *             the prediction.
 * \return the squared residual (non-negative unless an input is NaN).
 */
double extract_l2_error(const graph_type::edge_type & edge) {
  const double prediction =
      edge.source().data().factor.dot(edge.target().data().factor);
  const double residual = edge.data().obs - prediction;
  return residual * residual;
} // end of extract_l2_error