void GNaiveBayes::predictDistribution(const GVec& in, GPrediction* out) { if(m_nSampleCount <= 0) throw Ex("You must call train before you call eval"); for(size_t n = 0; n < m_pRelLabels->size(); n++) m_pOutputs[n]->eval(in.data(), &out[n], m_equivalentSampleSize); }
void GNaiveBayes::predict(const GVec& in, GVec& out) { if(m_nSampleCount <= 0) throw Ex("You must call train before you call eval"); for(size_t n = 0; n < m_pRelLabels->size(); n++) out[n] = m_pOutputs[n]->predict(in.data(), m_equivalentSampleSize, &m_rand); }
// virtual void GNaiveInstance::trainIncremental(const GVec& pIn, const GVec& pOut) { if(!m_pHeap) m_pHeap = new GHeap(1024); double* pOutputs = (double*)m_pHeap->allocAligned(sizeof(double) * m_pRelLabels->size()); GVec::copy(pOutputs, pOut.data(), m_pRelLabels->size()); for(size_t i = 0; i < m_pRelFeatures->size(); i++) { if(pIn[i] != UNKNOWN_REAL_VALUE) m_pAttrs[i]->addInstance(pIn[i], pOutputs); } }
// virtual void GWag::trainInner(const GMatrix& features, const GMatrix& labels) { GNeuralNetLearner* pTemp = NULL; std::unique_ptr<GNeuralNetLearner> hTemp; size_t weights = 0; GVec pWeightBuf; GVec pWeightBuf2; for(size_t i = 0; i < m_models; i++) { m_pNN->train(features, labels); if(pTemp) { // Average m_pNN with pTemp if(!m_noAlign) m_pNN->nn().align(pTemp->nn()); pTemp->nn().weightsToVector(pWeightBuf.data()); m_pNN->nn().weightsToVector(pWeightBuf2.data()); pWeightBuf *= (double(i) / (i + 1)); pWeightBuf.addScaled(1.0 / (i + 1), pWeightBuf2); pTemp->nn().vectorToWeights(pWeightBuf.data()); } else { // Copy the m_pNN GDom doc; GDomNode* pNode = m_pNN->serialize(&doc); GLearnerLoader ll; pTemp = new GNeuralNetLearner(pNode); hTemp.reset(pTemp); weights = pTemp->nn().weightCount(); pWeightBuf.resize(weights); pWeightBuf2.resize(weights); } } pTemp->nn().weightsToVector(pWeightBuf.data()); m_pNN->nn().vectorToWeights(pWeightBuf.data()); }
// virtual void GNaiveBayes::trainIncremental(const GVec& in, const GVec& out) { for(size_t n = 0; n < m_pRelLabels->size(); n++) m_pOutputs[n]->AddTrainingSample(in.data(), (int)out[n]); m_nSampleCount++; }