Evaluator::Evaluator(const FunctionPtr& evaluationFunction, const std::vector<ProgressWriterPtr>& progressWriters, bool initializeCombined)
    : m_evaluationFunction(evaluationFunction),
      m_aggregatedTestEvalCriterionValue(std::make_shared<Accumulator>()),
      m_progressWriters(progressWriters.begin(), progressWriters.end())
{
    // By default we set the number of threads to hardware concurrency.
    if (!Internal::MaxNumCPUThreadsSet())
        SetMaxNumCPUThreads(std::thread::hardware_concurrency());

    // A null evaluation function is only allowed for the derived classes.
    if (!m_evaluationFunction)
    {
        if (initializeCombined)
            InvalidArgument("Eval function is not allowed to be null.");
        return;
    }

    if (!m_evaluationFunction->Output().DynamicAxes().empty())
    {
        m_aggregatedEvaluationFunction = ReduceSum(m_evaluationFunction, Axis::AllAxes(), L"aggregateEvalMetric");
        m_testSampleCountVar = m_evaluationFunction;
    }
    else
    {
        m_aggregatedEvaluationFunction = m_evaluationFunction;
        m_testSampleCountVar = m_evaluationFunction->RootFunction()->Inputs()[0];
    }

    if (initializeCombined)
        m_combinedEvalFunction = Combine(GetCombinedEvalFunctionArgs());
}
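// Illustrative sketch only (not part of the library source): one way this constructor is
// typically reached, via the public CreateEvaluator() factory, which leaves initializeCombined
// at its default. The metric choice and variable names below are assumptions for the example;
// the #if 0 guard keeps this translation unit unchanged.
#if 0
static double EvaluatorUsageSketch(const FunctionPtr& model, const Variable& labels,
                                   const std::unordered_map<Variable, ValuePtr>& minibatch,
                                   const DeviceDescriptor& device)
{
    // Wrap a classification-error metric and hand it to the factory.
    auto metric = ClassificationError(model, labels, L"metric");
    auto evaluator = CreateEvaluator(metric);

    // Returns the average metric value over the samples bound in 'minibatch'.
    return evaluator->TestMinibatch(minibatch, device);
}
#endif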
Trainer::Trainer(const FunctionPtr& model, const FunctionPtr& lossFunction, const FunctionPtr& evaluationFunction,
                 const std::vector<LearnerPtr>& parameterLearners, const std::vector<ProgressWriterPtr>& progressWriters)
    : Evaluator(evaluationFunction, progressWriters, false),
      m_model(model),
      m_lossFunction(lossFunction),
      m_parameterLearners(std::make_shared<Learners>(parameterLearners)),
      m_prevMinibatchNumSamples(0),
      m_distributed(false),
      m_aggregatedTrainingLossValue(std::make_shared<Accumulator>()),
      m_aggregatedTrainingEvalCriterionValue(),
      m_prevDistributedTotalNumSamples(0)
{
    std::vector<Variable> combinedFunctionArgs;
    if (m_model) // The model is optional, since it may not add any information on top of the loss function.
        combinedFunctionArgs = m_model->Outputs();

    combinedFunctionArgs.push_back(m_lossFunction);

    if (m_lossFunction->Output().GetDataType() == DataType::Float16)
        fprintf(stderr, "WARNING: using Float16 for the loss function may cause overflow; please cast to float.\n");

    if (!m_lossFunction->Output().DynamicAxes().empty())
    {
        m_aggregatedLossFunction = ReduceSum(lossFunction, Axis::AllAxes(), L"aggregateLoss");
        combinedFunctionArgs.push_back(m_aggregatedLossFunction);
        m_trainingSampleCountVar = m_lossFunction;
    }
    else
    {
        m_aggregatedLossFunction = m_lossFunction;

        // The loss output has no dynamic (batch) axis, so recursively search the graph underlying
        // the scalar loss for a Variable that does have one; that Variable is then used to
        // determine the per-minibatch training sample count.
        std::function<std::pair<Variable, bool>(const FunctionPtr& root)> FindTrainingSampleCountVar;
        FindTrainingSampleCountVar = [&FindTrainingSampleCountVar](const FunctionPtr& root) -> std::pair<Variable, bool>
        {
            const auto& outputs = root->Outputs();
            auto firstOutputWithDynamicAxes = std::find_if(outputs.begin(), outputs.end(), [](const Variable& var) { return !var.DynamicAxes().empty(); });
            if (firstOutputWithDynamicAxes != outputs.end())
                return std::make_pair(*firstOutputWithDynamicAxes, true);

            const auto& inputs = root->Inputs();
            for (const auto& input : inputs)
            {
                if (!input.DynamicAxes().empty())
                    return std::make_pair(input, true);

                if (input.IsOutput())
                {
                    auto retVal = FindTrainingSampleCountVar(input.Owner());
                    if (retVal.second)
                        return retVal;
                }
            }

            return std::make_pair(Variable(), false);
        };

        auto findTrainingSampleCountVarRetVal = FindTrainingSampleCountVar(m_lossFunction->RootFunction());
        if (!findTrainingSampleCountVarRetVal.second)
            InvalidArgument("Trainer: Failed to find a variable underlying the graph rooted at specified loss function '%S', from which the training sample count can be determined.",
                            m_lossFunction->RootFunction()->AsString().c_str());

        m_trainingSampleCountVar = findTrainingSampleCountVarRetVal.first;
        if (GetTraceLevel() >= TraceLevel::Info)
            fprintf(stderr, "Info: Trainer loss Function '%S' output does not have a batch axis; the first Variable '%S' with a batch axis found in the graph underlying the scalar "
                            "loss Function will be used to determine the minibatch training sample count.\n",
                    m_lossFunction->AsString().c_str(), m_trainingSampleCountVar.AsString().c_str());

        if (std::find(combinedFunctionArgs.begin(), combinedFunctionArgs.end(), m_trainingSampleCountVar) == combinedFunctionArgs.end())
            combinedFunctionArgs.push_back(m_trainingSampleCountVar);
    }

    if (evaluationFunction)
    {
        auto evalArgs = GetCombinedEvalFunctionArgs();
        combinedFunctionArgs.insert(combinedFunctionArgs.end(), evalArgs.begin(), evalArgs.end());

        m_aggregatedTrainingEvalCriterionValue = std::make_shared<Accumulator>();
    }

    // Create a default eval value in case there is no evaluation criterion.
    m_prevMinibatchAggregateEvalCriterionValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(0, m_aggregatedLossFunction->Output().GetDataType(), NDShape{}, DeviceDescriptor::CPUDevice()));

    m_combinedTrainingFunction = Combine(combinedFunctionArgs);
    SetCombinedEvalFunction(m_combinedTrainingFunction);

    auto modelParameters = m_combinedTrainingFunction->Parameters();
    m_learnerParameters = m_parameterLearners->GetParameters();
    std::unordered_set<Parameter> modelParametersSet(modelParameters.begin(), modelParameters.end());
    std::unordered_set<Parameter> learnerParametersNotPartOfModel;
    for (const auto& learnerParameter : m_learnerParameters)
    {
        if (modelParametersSet.find(learnerParameter) == modelParametersSet.end())
            learnerParametersNotPartOfModel.insert(learnerParameter);
    }

    for (const auto& modelParameter : modelParametersSet)
    {
        if (m_learnerParameters.find(modelParameter) == m_learnerParameters.end())
            m_modelParametersNotCoveredByLearners.insert(modelParameter);
    }

    if (!learnerParametersNotPartOfModel.empty())
        InvalidArgument("Trainer ctor: %d of the learner parameters '%S' are not part of the model specified",
                        (int)learnerParametersNotPartOfModel.size(), NamedListString(learnerParametersNotPartOfModel).c_str());

    if (!m_modelParametersNotCoveredByLearners.empty())
        fprintf(stderr, "[Note:] Trainer ctor: %d of the model parameters are not covered by any of the specified Learners; these parameters will not be learned.\n",
                (int)m_modelParametersNotCoveredByLearners.size());

    m_distributed = m_parameterLearners->IsDistributed();
    if (m_distributed)
        Evaluator::SetCommunicator(dynamic_cast<DistributedLearner*>(m_parameterLearners->ParameterLearners()[0].get())->GetCommunicator());

    for (auto& learner : m_parameterLearners->ParameterLearners())
    {
        learner->AddProgressWriters(progressWriters);
    }
}
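// Illustrative sketch only (not part of the library source): how a client might exercise this
// constructor through the public CreateTrainer() factory. The layer shape, variable names and
// learning-rate value are assumptions chosen for the example; the #if 0 guard keeps this
// translation unit unchanged.
#if 0
static void TrainerUsageSketch(const DeviceDescriptor& device)
{
    auto features = InputVariable({ 784 }, DataType::Float, L"features");
    auto labels   = InputVariable({ 10 },  DataType::Float, L"labels");

    // A single dense layer standing in for a real model.
    auto W = Parameter({ 10, 784 }, DataType::Float, GlorotUniformInitializer(), device);
    auto b = Parameter({ 10 }, DataType::Float, 0.0, device);
    auto model = Plus(Times(W, features), b, L"model");

    auto loss   = CrossEntropyWithSoftmax(model, labels, L"loss");
    auto metric = ClassificationError(model, labels, L"metric");

    // The loss output carries the batch axis, so the constructor above takes the
    // ReduceSum(..., Axis::AllAxes()) branch when aggregating it.
    auto learner = SGDLearner(model->Parameters(), LearningRatePerSampleSchedule(0.01));
    auto trainer = CreateTrainer(model, loss, metric, { learner });

    // Per-minibatch data would be bound to 'features' and 'labels' here, e.g.:
    // trainer->TrainMinibatch({ { features, featureValue }, { labels, labelValue } }, device);
}
#endif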