Example #1
    Evaluator::Evaluator(
        const FunctionPtr& evaluationFunction,
        const std::vector<ProgressWriterPtr>& progressWriters,
        bool initializeCombined)
        :  m_evaluationFunction(evaluationFunction),
           m_aggregatedTestEvalCriterionValue(std::make_shared<Accumulator>()),
           m_progressWriters(progressWriters.begin(), progressWriters.end())
    {
        // By default we set the number of threads to hardware concurrency.
        if (!Internal::MaxNumCPUThreadsSet())
            SetMaxNumCPUThreads(std::thread::hardware_concurrency());

        // Nullptr evaluation is only allowed by the derived classes.
        if (!m_evaluationFunction)
        {
            if (initializeCombined)
                InvalidArgument("Eval function is not allowed to be null.");
            return;
        }

        if (!m_evaluationFunction->Output().DynamicAxes().empty())
        {
            m_aggregatedEvaluationFunction = ReduceSum(m_evaluationFunction, Axis::AllAxes(), L"aggregateEvalMetric");
            m_testSampleCountVar = m_evaluationFunction;
        }
        else
        {
            m_aggregatedEvaluationFunction = m_evaluationFunction;
            m_testSampleCountVar = m_evaluationFunction->RootFunction()->Inputs()[0];
        }

        if (initializeCombined)
            m_combinedEvalFunction = Combine(GetCombinedEvalFunctionArgs());
    }
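
This constructor is normally reached through CNTK's CreateEvaluator factory rather than called directly. The sketch below is a minimal, hypothetical usage of that path: z (the model), features, labels, and the two Value objects are assumptions made for illustration and are not part of the code above.

    #include "CNTKLibrary.h"
    #include <cstdio>

    using namespace CNTK;

    // Hypothetical sketch: build a per-sample metric and evaluate one minibatch.
    // `z`, `features`, `labels`, `featureValue`, and `labelValue` are assumed to
    // be created elsewhere; the names are illustrative only.
    void EvaluateOneMinibatch(const FunctionPtr& z, const Variable& features, const Variable& labels,
                              const ValuePtr& featureValue, const ValuePtr& labelValue)
    {
        // The metric output has dynamic (batch) axes, so the constructor above wraps
        // it in ReduceSum over all axes to aggregate it across the minibatch.
        auto metric = ClassificationError(z, labels, L"error");

        auto evaluator = CreateEvaluator(metric);
        double avgError = evaluator->TestMinibatch({ { features, featureValue }, { labels, labelValue } },
                                                   DeviceDescriptor::UseDefaultDevice());
        fprintf(stderr, "Average error on minibatch: %f\n", avgError);
    }
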
Example #2
    Trainer::Trainer(const FunctionPtr& model, const FunctionPtr& lossFunction, const FunctionPtr& evaluationFunction,
                     const std::vector<LearnerPtr>& parameterLearners,
                     const std::vector<ProgressWriterPtr>& progressWriters)
        : Evaluator(evaluationFunction, progressWriters, false),
          m_model(model),
          m_lossFunction(lossFunction),
          m_parameterLearners(std::make_shared<Learners>(parameterLearners)),
          m_prevMinibatchNumSamples(0),
          m_distributed(false),
          m_aggregatedTrainingLossValue(std::make_shared<Accumulator>()),
          m_aggregatedTrainingEvalCriterionValue(),
          m_prevDistributedTotalNumSamples(0)
    {
        std::vector<Variable> combinedFunctionArgs;
        if (m_model) // model is optional, since it may not be adding any information on top of lossFunction
            combinedFunctionArgs = m_model->Outputs();

        combinedFunctionArgs.push_back(m_lossFunction);

        if (m_lossFunction->Output().GetDataType() == DataType::Float16)
            fprintf(stderr, "WARNING: using Float16 for loss function may cause overflow, please cast to float");

        if (!m_lossFunction->Output().DynamicAxes().empty())
        {
            m_aggregatedLossFunction = ReduceSum(lossFunction, Axis::AllAxes(), L"aggregateLoss");
            combinedFunctionArgs.push_back(m_aggregatedLossFunction);
            m_trainingSampleCountVar = m_lossFunction;
        }
        else
        {
            m_aggregatedLossFunction = m_lossFunction;

            // The scalar loss has no dynamic (batch) axis, so recursively search the graph
            // beneath it for the first variable that does; its dynamic axes determine the
            // minibatch training sample count.
            std::function<std::pair<Variable, bool>(const FunctionPtr& root)> FindTrainingSampleCountVar;
            FindTrainingSampleCountVar = [&FindTrainingSampleCountVar](const FunctionPtr& root) -> std::pair<Variable, bool> {
                const auto& outputs = root->Outputs();
                auto firstOutputWithDynamicAxes = std::find_if(outputs.begin(), outputs.end(), [](const Variable& var) { return !var.DynamicAxes().empty(); });
                if (firstOutputWithDynamicAxes != outputs.end())
                    return std::make_pair(*firstOutputWithDynamicAxes, true);

                const auto& inputs = root->Inputs();
                for (const auto& input : inputs)
                {
                    if (!input.DynamicAxes().empty())
                        return std::make_pair(input, true);

                    if (input.IsOutput())
                    {
                        auto retVal = FindTrainingSampleCountVar(input.Owner());
                        if (retVal.second)
                            return retVal;
                    }
                }

                return std::make_pair(Variable(), false);
            };

            auto findTrainingSampleCountVarRetVal = FindTrainingSampleCountVar(m_lossFunction->RootFunction());
            if (!findTrainingSampleCountVarRetVal.second)
                InvalidArgument("Trainer: Failed to find a variable underlying the graph rooted at specified loss function '%S', from which the training sample count can be determined.", m_lossFunction->RootFunction()->AsString().c_str());

            m_trainingSampleCountVar = findTrainingSampleCountVarRetVal.first;
            if (GetTraceLevel() >= TraceLevel::Info)
                fprintf(stderr, "Info: Trainer loss Function '%S' output does not have a batch axis; the first Variable '%S' with a batch axis found in the graph underlying the scalar "
                                "loss Function will be used to determine minibatch training sample count.\n", m_lossFunction->AsString().c_str(), m_trainingSampleCountVar.AsString().c_str());

            if (std::find(combinedFunctionArgs.begin(), combinedFunctionArgs.end(), m_trainingSampleCountVar) == combinedFunctionArgs.end())
                combinedFunctionArgs.push_back(m_trainingSampleCountVar);
        }

        if (evaluationFunction)
        {
            auto evalArgs = GetCombinedEvalFunctionArgs();
            combinedFunctionArgs.insert(combinedFunctionArgs.end(), evalArgs.begin(), evalArgs.end());

            m_aggregatedTrainingEvalCriterionValue = std::make_shared<Accumulator>();
        }

        // create a default eval value in case there's no criterion
        m_prevMinibatchAggregateEvalCriterionValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(0, m_aggregatedLossFunction->Output().GetDataType(), NDShape{}, DeviceDescriptor::CPUDevice()));

        m_combinedTrainingFunction = Combine(combinedFunctionArgs);
        SetCombinedEvalFunction(m_combinedTrainingFunction);

        // Cross-check parameter coverage: every learner parameter must belong to the model,
        // while model parameters without a learner are only reported and left untouched.
        auto modelParameters = m_combinedTrainingFunction->Parameters();
        m_learnerParameters = m_parameterLearners->GetParameters();
        std::unordered_set<Parameter> modelParametersSet(modelParameters.begin(), modelParameters.end());
        std::unordered_set<Parameter> learnerParametersNotPartOfModel;
        for (const auto& learnerParameter : m_learnerParameters)
        {
            if (modelParametersSet.find(learnerParameter) == modelParametersSet.end())
                learnerParametersNotPartOfModel.insert(learnerParameter);
        }

        for (const auto& modelParameter : modelParametersSet)
        {
            if (m_learnerParameters.find(modelParameter) == m_learnerParameters.end())
                m_modelParametersNotCoveredByLearners.insert(modelParameter);
        }

        if (!learnerParametersNotPartOfModel.empty())
            InvalidArgument("Trainer ctor: %d of the learner parameters '%S' are not part of the model specified", 
                            (int)learnerParametersNotPartOfModel.size(), NamedListString(learnerParametersNotPartOfModel).c_str());

        if (!m_modelParametersNotCoveredByLearners.empty())
            fprintf(stderr, "[Note:] Trainer ctor: %d of the model parameters are not covered by any of the specified Learners; these parameters will not be learned\n", (int)m_modelParametersNotCoveredByLearners.size());

        m_distributed = m_parameterLearners->IsDistributed();

        if (m_distributed)
            Evaluator::SetCommunicator(dynamic_cast<DistributedLearner*>(m_parameterLearners->ParameterLearners()[0].get())->GetCommunicator());

        for (auto& learner : m_parameterLearners->ParameterLearners())
        {
            learner->AddProgressWriters(progressWriters);
        }
    }
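
As with the Evaluator, this constructor is usually reached through the CreateTrainer factory. The sketch below shows, under stated assumptions, the wiring it expects: z (the model), labels, and the input Values are illustrative, and the exact learning-rate schedule types vary slightly across CNTK versions.

    #include "CNTKLibrary.h"
    #include <cstdio>

    using namespace CNTK;

    // Hypothetical sketch: construct a Trainer around a model built elsewhere and
    // run a single training step. `z`, `labels`, and the input Values are assumptions.
    void TrainOneMinibatch(const FunctionPtr& z, const Variable& features, const Variable& labels,
                           const ValuePtr& featureValue, const ValuePtr& labelValue)
    {
        // Scalar-per-sample loss with dynamic axes: the constructor above wraps it in
        // ReduceSum(..., Axis::AllAxes()) and uses the loss itself as the sample-count variable.
        auto loss   = CrossEntropyWithSoftmax(z, labels, L"loss");
        auto metric = ClassificationError(z, labels, L"error");

        // One SGD learner over all model parameters; parameters omitted here would trigger
        // the "not covered by any of the specified Learners" note printed above.
        auto learner = SGDLearner(z->Parameters(), LearningRatePerSampleSchedule(0.01));

        auto trainer = CreateTrainer(z, loss, metric, { learner });
        trainer->TrainMinibatch({ { features, featureValue }, { labels, labelValue } },
                                DeviceDescriptor::UseDefaultDevice());
        fprintf(stderr, "Loss: %f, metric: %f\n",
                trainer->PreviousMinibatchLossAverage(), trainer->PreviousMinibatchEvaluationAverage());
    }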