Ejemplo n.º 1
0
    void TCatboostOptions::Validate() const {
        SystemOptions.Get().Validate();
        BoostingOptions.Get().Validate();
        ObliviousTreeOptions.Get().Validate();

        ELossFunction lossFunction = LossFunctionDescription->GetLossFunction();
        {
            const ui32 classesCount = DataProcessingOptions->ClassesCount;
            if (classesCount != 0 ) {
                CB_ENSURE(IsMultiClassError(lossFunction), "classes_count parameter takes effect only with MultiClass/MultiClassOneVsAll loss functions");
                CB_ENSURE(classesCount > 1, "classes-count should be at least 2");
            }
            const auto& classWeights = DataProcessingOptions->ClassWeights.Get();
            if (!classWeights.empty()) {
                CB_ENSURE(lossFunction == ELossFunction::Logloss || IsMultiClassError(lossFunction),
                          "class weights takes effect only with Logloss, MultiClass and MultiClassOneVsAll loss functions");
                CB_ENSURE(IsMultiClassError(lossFunction) || (classWeights.size() == 2),
                          "if loss-function is Logloss, then class weights should be given for 0 and 1 classes");
                CB_ENSURE(classesCount == 0 || classesCount == classWeights.size(), "class weights should be specified for each class in range 0, ... , classes_count - 1");
            }
        }

        ELeavesEstimation leavesEstimation = ObliviousTreeOptions->LeavesEstimationMethod;
        if (lossFunction == ELossFunction::Quantile ||
            lossFunction == ELossFunction::MAE ||
            lossFunction == ELossFunction::LogLinQuantile ||
            lossFunction == ELossFunction::MAPE)
        {
            CB_ENSURE(leavesEstimation != ELeavesEstimation::Newton,
                      "Newton leave estimation method is not supported for " << lossFunction << " loss function");
            CB_ENSURE(ObliviousTreeOptions->LeavesEstimationIterations == 1U,
                      "gradient_iterations should equals 1 for this mode");
        }

        if (GetTaskType() == ETaskType::CPU) {
            CB_ENSURE(!(IsQuerywiseError(lossFunction) && leavesEstimation == ELeavesEstimation::Newton),
                      "This leaf estimation method is not supported for querywise error for CPU learning");

            CB_ENSURE(!(IsPairwiseError(lossFunction) && leavesEstimation == ELeavesEstimation::Newton),
                      "This leaf estimation method is not supported for pairwise error");
        }


        ValidateCtrs(CatFeatureParams->SimpleCtrs, lossFunction, false);
        for (const auto& perFeatureCtr : CatFeatureParams->PerFeatureCtrs.Get()) {
            ValidateCtrs(perFeatureCtr.second, lossFunction, false);
        }
        ValidateCtrs(CatFeatureParams->CombinationCtrs, lossFunction, true);
    }
void TLearnContext::InitContext(const TDataset& learnData, const TDatasetPtrs& testDataPtrs) {
    LearnProgress.PoolCheckSum = CalcFeaturesCheckSum(learnData.AllFeatures);
    for (const TDataset* testData : testDataPtrs) {
        LearnProgress.PoolCheckSum += CalcFeaturesCheckSum(testData->AllFeatures);
    }

    auto lossFunction = Params.LossFunctionDescription->GetLossFunction();
    int foldCount = Max<ui32>(Params.BoostingOptions->PermutationCount - 1, 1);
    const bool noCtrs = IsCategoricalFeaturesEmpty(learnData.AllFeatures);
    if (Params.BoostingOptions->BoostingType == EBoostingType::Plain && noCtrs) {
        foldCount = 1;
    }
    LearnProgress.Folds.reserve(foldCount);
    UpdateCtrsTargetBordersOption(lossFunction, LearnProgress.ApproxDimension, &Params.CatFeatureParams.Get());

    CtrsHelper.InitCtrHelper(Params.CatFeatureParams,
                             Layout,
                             learnData.Target,
                             lossFunction,
                             ObjectiveDescriptor);

    //Todo(noxoomo): check and init
    const auto& boostingOptions = Params.BoostingOptions.Get();
    ui32 foldPermutationBlockSize = boostingOptions.PermutationBlockSize;
    if (foldPermutationBlockSize == FoldPermutationBlockSizeNotSet) {
        foldPermutationBlockSize = DefaultFoldPermutationBlockSize(learnData.GetSampleCount());
    }
    if (IsPlainMode(Params.BoostingOptions->BoostingType) && noCtrs) {
        foldPermutationBlockSize = learnData.GetSampleCount();
    }
    const auto storeExpApproxes = IsStoreExpApprox(Params.LossFunctionDescription->GetLossFunction());
    const bool hasPairwiseWeights = IsPairwiseError(Params.LossFunctionDescription->GetLossFunction());

    if (IsPlainMode(Params.BoostingOptions->BoostingType)) {
        for (int foldIdx = 0; foldIdx < foldCount; ++foldIdx) {
            LearnProgress.Folds.emplace_back(
                BuildPlainFold(
                    learnData,
                    CtrsHelper.GetTargetClassifiers(),
                    foldIdx != 0,
                    (Params.SystemOptions->IsSingleHost() ? foldPermutationBlockSize : learnData.GetSampleCount()),
                    LearnProgress.ApproxDimension,
                    storeExpApproxes,
                    hasPairwiseWeights,
                    Rand
                )
            );
        }
    } else {
        for (int foldIdx = 0; foldIdx < foldCount; ++foldIdx) {
            LearnProgress.Folds.emplace_back(
                BuildDynamicFold(
                    learnData,
                    CtrsHelper.GetTargetClassifiers(),
                    foldIdx != 0,
                    foldPermutationBlockSize,
                    LearnProgress.ApproxDimension,
                    boostingOptions.FoldLenMultiplier,
                    storeExpApproxes,
                    hasPairwiseWeights,
                    Rand
                )
            );
        }
    }

    LearnProgress.AveragingFold = BuildPlainFold(
        learnData,
        CtrsHelper.GetTargetClassifiers(),
        !(Params.DataProcessingOptions->HasTimeFlag),
        /*permuteBlockSize=*/ (Params.SystemOptions->IsSingleHost() ? foldPermutationBlockSize : learnData.GetSampleCount()),
        LearnProgress.ApproxDimension,
        storeExpApproxes,
        hasPairwiseWeights,
        Rand
    );

    LearnProgress.AvrgApprox.resize(LearnProgress.ApproxDimension, TVector<double>(learnData.GetSampleCount()));
    if (!learnData.Baseline.empty()) {
        LearnProgress.AvrgApprox = learnData.Baseline;
    }
    ResizeRank2(testDataPtrs.size(), LearnProgress.ApproxDimension, LearnProgress.TestApprox);
    for (size_t testIdx = 0; testIdx < testDataPtrs.size(); ++testIdx) {
        const auto* testData = testDataPtrs[testIdx];
        if (testData == nullptr || testData->GetSampleCount() == 0) {
            continue;
        }
        if (testData->Baseline.empty()) {
            for (auto& approxDim : LearnProgress.TestApprox[testIdx]) {
                approxDim.resize(testData->GetSampleCount());
            }
        } else {
            LearnProgress.TestApprox[testIdx] = testData->Baseline;
        }
    }
}