Example #1
0
File: linear.cpp Project: cabiling/madlib
// -----------------------------------------------------------------------
// Linear regression
// -----------------------------------------------------------------------
// Transition function for the linear-regression aggregate.
//
// args[0] is the current state, args[1] the dependent value y (read as a
// double) and args[2] the independent-variable array x. A row whose y or x is
// NULL, or whose x contains a NULL element, is skipped and args[0] is returned
// unchanged. Otherwise the (x, y) row is fed into the state via operator<< and
// the updated state storage is returned.
AnyType
linregr_transition::run(AnyType& args) {
    MutableLinRegrState linregr_state = args[0].getAs<MutableByteString>();

    // Skip rows with a NULL dependent value or a NULL independent array.
    const bool y_is_null = args[1].isNull();
    if (y_is_null || args[2].isNull()) {
        return args[0];
    }

    double dependent = args[1].getAs<double>();
    MappedColumnVector independent;
    try {
        MappedColumnVector raw = args[2].getAs<MappedColumnVector>();
        independent.rebind(raw.memoryHandle(), raw.size());
    } catch (const ArrayWithNullException &) {
        // The array contains a NULL element: skip this row as well.
        return args[0];
    }

    const MutableLinRegrState::tuple_type row(independent, dependent);
    linregr_state << row;
    return linregr_state.storage();
}
// Transition function for the multinomial-logit multi-response GLM aggregate.
//
// Arguments, as read below:
//   args[0]  current aggregate state (byte string)
//   args[1]  dependent value y, read as a double
//   args[2]  independent-variable array x
//   args[3]  state from a previous iteration, or NULL
//   args[4]  number of categories, read as uint16_t (only on the first row)
//
// Returns args[0] unchanged when the state's terminated flag is set, when y or
// x is NULL, or when x contains a NULL element. Otherwise the (x, y) row is fed
// into the state via operator<< and the updated state storage is returned.
//
// Throws std::runtime_error on the first row if
// num_features * (num_categories - 1) exceeds 11584 (see MADLIB-667 below).
AnyType
multi_response_glm_multinom_logit_transition::run(AnyType& args) {
    MutableMultiResponseGLMState state = args[0].getAs<MutableByteString>();
    if (state.terminated || args[1].isNull() || args[2].isNull()) {
        return args[0];
    }
    double y = args[1].getAs<double>();
    MappedColumnVector x;
    try {
        MappedColumnVector xx = args[2].getAs<MappedColumnVector>();
        x.rebind(xx.memoryHandle(), xx.size());
    } catch (const ArrayWithNullException &) {
        // x contains a NULL element: skip this row.
        return args[0];
    }
    if (state.empty()) {
        // First row seen with an empty state: size the state.
        const uint16_t num_categories = args[4].getAs<uint16_t>();

        // MADLIB-667: GPDB limits a single array to 1GB, which means that a
        // double array cannot hold more than 134217727 elements because
        // (134217727 * 8) / (1024 * 1024) = 1023. Solving
        // state_size = x^2 + 2x + 6 <= 134217727 gives x <= 11584.
        //
        // num_coef is computed and checked in 64-bit arithmetic BEFORE it is
        // narrowed into the uint16_t state fields. Narrowing first lets
        // products >= 65536 wrap to small values that pass the size check and
        // leave the state sized for the wrong number of coefficients; squaring
        // a promoted uint16_t in int can also overflow. Checking num_coef
        // directly is equivalent to checking state_size > 134217727 and cannot
        // overflow. (num_categories == 0 makes the unsigned subtraction wrap
        // to a huge value, so it is rejected too unless x is empty.)
        const uint64_t max_num_coef = 11584;
        const uint64_t num_coef = static_cast<uint64_t>(x.size()) *
                (static_cast<uint64_t>(num_categories) - 1);
        if (num_coef > max_num_coef) {
            throw std::runtime_error(
                "The product of number of independent variables and number of "
                "categories cannot be larger than 11584.");
        }

        state.num_features = static_cast<uint16_t>(x.size());
        state.num_categories = num_categories;
        // Safe: num_coef <= 11584 after the check above.
        state.optimizer.num_coef = static_cast<uint16_t>(num_coef);

        state.resize();
        if (!args[3].isNull()) {
            // Start from the previous iteration's state, then call reset() on
            // it (presumably clears per-iteration accumulators — confirm).
            MultiResponseGLMState prev_state = args[3].getAs<ByteString>();
            state = prev_state;
            state.reset();
        }
    }
    state << MutableMultiResponseGLMState::tuple_type(x, y);
    return state.storage();
}