Ejemplo n.º 1
0
    // Restores the trainer from the checkpoint associated with `modelFilePath`.
    //
    // Order of operations:
    //   1. Restores the parameters of m_combinedTrainingFunction from `modelFilePath`.
    //   2. Loads the trainer-state dictionary from GetTrainerStateCheckpointFilePath(modelFilePath)
    //      and restores the parameter learners from its learner entry.
    //   3. Non-distributed: returns the checkpoint's external-state dictionary as-is.
    //   4. Distributed: waits on an MPI barrier, then selects the state that belongs to this worker
    //      (see the version-specific comments below).
    //
    // Returns: the external-state dictionary that applies to the calling worker.
    // Errors: raises RuntimeError when a worker-specific internal state must be applied but
    //         m_combinedTrainingFunction is not a CompositeFunction.
    Dictionary Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
    {
        // Restore the model's parameters
        m_combinedTrainingFunction->Restore(modelFilePath);

        Dictionary checkpoint = Dictionary::Load(GetTrainerStateCheckpointFilePath(modelFilePath));

        // Checkpoints that carry no version entry are treated as version 0.
        size_t version = 0;

        if (checkpoint.Contains(versionPropertyName))
            version = checkpoint[versionPropertyName].Value<size_t>();
        
        auto learnerState = checkpoint[learnersPropertyName].Value<std::vector<DictionaryValue>>();
        auto externalState = checkpoint[externalStatePropertyName].Value<Dictionary>();

        m_parameterLearners->RestoreFromCheckpoint(learnerState);

        if (!m_distributed)
        {
            return externalState;
        }

        // this ensures that nobody will start writing to the model/checkpoint files, until
        // everybody is done reading them.
        DistributedCommunicatorPtr communicator = MPICommunicator();
        communicator->Barrier();

        // Worker entries are keyed by the worker's global rank rendered as a wide string;
        // rank 0 is used as the main worker's key.
        auto mainWorkerId = std::to_wstring(0);
        auto localWorkerId = std::to_wstring(communicator->CurrentWorker().m_globalRank);

        // before version 1, there was no distributed state per se. Instead, the external state
        // contained a dictionary of worker-specific external states.
        if (version == 0)
        {
            // Fall back to the main worker's entry when this worker has none of its own.
            auto key = externalState.Contains(localWorkerId) ? localWorkerId : mainWorkerId;
            return externalState[key].Value<Dictionary>();
        }

        Dictionary distributedState = checkpoint[distributedStatePropertyName].Value<Dictionary>();

        // The main worker, and any worker without a dedicated entry, use the shared external state.
        if (communicator->CurrentWorker().IsMain() || !distributedState.Contains(localWorkerId))
        {
            return externalState;
        }
        
        // the checkpoint contains internal state for this worker.
        Dictionary localState = distributedState[localWorkerId].Value<Dictionary>();

        auto internalState = localState[internalWorkerStateKey].Value<Dictionary>();
        auto compositeFunction = std::dynamic_pointer_cast<CompositeFunction>(m_combinedTrainingFunction);
        if (compositeFunction == nullptr)
            RuntimeError("Combined training function is not a CompositeFunction.");
            
        // this assumes the compositeFunction (restored from a checkpoint made by the main node) and 
        // the internal worker state both have identical UIDs.
        compositeFunction->SetInternalState(internalState);
        
        return localState[externalWorkerStateKey].Value<Dictionary>();
    }
Ejemplo n.º 2
0
// Dictionary load/contains/find test.
// Writes the integers [0, dict_size) one per line to a scratch file "dict", loads it into a
// Dictionary, and checks that every integer (as a string):
//   - is reported by Contains(),
//   - is located by Find() with a matching key,
//   - has a mapped value that compares equal to false.
// Prints "Test9:\tpassed" or "Test9:\tfailed" and deletes the scratch file.
void Test::Test9()
{
    const size_t dict_size = 10;
    const char* const dict_path = "dict";

    // Write the fixture. '\n' rather than endl: close() flushes once at the end,
    // so flushing after every line is wasted work.
    ofstream f(dict_path);
    for (size_t i = 0; i < dict_size; ++i)
        f << i << '\n';
    f.close();

    Dictionary dict;
    string msg;
    bool res = dict.Load(dict_path, 1, msg);

    // Stop at the first failing entry; `res` carries the overall verdict.
    for (size_t i = 0; res && i < dict_size; ++i)
    {
        string str = to_string(i);
        res = dict.Contains(str);

        Dictionary::DictionaryData::iterator di;
        // Short-circuiting guarantees `di` is only dereferenced after Find() succeeded.
        res = res && dict.Find(str, di);
        res = res && di->first.compare(str) == 0;
        res = res && di->second == false;
    }

    // Use a literal format string; never pass a non-literal as printf's format argument.
    printf("%s", res ? "Test9:\tpassed\r\n" : "Test9:\tfailed\r\n");

    remove(dict_path);
}
Ejemplo n.º 3
0
    // Reconstructs a user-defined function (UDF) from its serialized dictionary.
    //
    // dict              - serialized UDF; must contain the type, uid, inputs and user-defined-state
    //                     entries (checked by ValidateDictionary). The name entry is optional.
    // uidToVariableMap  - used by GetInputVariables to resolve the UDF's inputs by uid.
    // device            - not referenced by this function.
    //
    // Native UDFs are rebuilt through Function::DeserializeNativeImpl. Other UDFs go through the
    // registered SWIG callback, if one is installed. Raises RuntimeError when neither path yields
    // a function.
    // On success the returned function carries the original uid from `dict`.
    /*static*/ FunctionPtr UDFUtils::Deserialize(const Dictionary& dict,
        const unordered_map<std::wstring, Variable>& uidToVariableMap,
        const DeviceDescriptor& device)
    {
        static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, uidKey, inputsKey, userDefinedStateKey };
        ValidateDictionary<PrimitiveFunction>(dict, s_requiredDictionaryKeys, s_userDefinedFunctionTypeValue, s_serializationVersion);

        const auto& uid = dict[uidKey].Value<std::wstring>();
        std::wstring name = L"";
        if (dict.Contains(nameKey))
            name = dict[nameKey].Value<std::wstring>();

        auto inputs = GetInputVariables(dict, uidToVariableMap, s_serializationVersion);

        auto state = dict[userDefinedStateKey].Value<Dictionary>();

        FunctionPtr udf;

        if (IsNativeUDF(dict))
        {
            udf = Function::DeserializeNativeImpl(inputs, name, state);
        }
        else if (s_SWIGCallbackWrapper != nullptr)
        {
            // If we're being called from SWIG, the actual deserializer should be registered by
            // the target language CNTK implementation (i.e., cntk_py for Python)
            udf = s_SWIGCallbackWrapper->operator()(inputs, name, state);
        }

        if (udf == nullptr)
        {
            // %ls is the standard conversion for wchar_t* arguments. %S is a platform-specific
            // extension; %ls also matches the other deserializers.
            RuntimeError("Unable to reconstruct a user-defined function (name = %ls, uid = %ls). "
                "Please make sure to specify a valid UDF deserializer.", name.c_str(), uid.c_str());
        }

        // Restore the original uid, which other functions in the graph depend on
        // (their inputs refer to the uids of this UDF's outputs, which are generated based on the uid of this UDF).
        udf->m_uid = uid;

        return udf;
    }
Ejemplo n.º 4
0
 /*static*/ bool UDFUtils::IsNativeUDF(const Dictionary& dict)
 {
     assert(IsUDF(dict));
     return (dict.Contains(nativeUDFKey) && dict[nativeUDFKey].Value<bool>() == true);
 }
Ejemplo n.º 5
0
 // Reports whether `dict` describes a user-defined function: it must contain a type entry
 // whose value equals s_userDefinedFunctionTypeValue.
 /*static*/ bool UDFUtils::IsUDF(const Dictionary& dict)
 {
     if (!dict.Contains(typeKey))
         return false;

     const auto& serializedType = dict[typeKey].Value<std::wstring>();
     return serializedType == s_userDefinedFunctionTypeValue;
 }
Ejemplo n.º 6
0
    // Reconstructs a Variable from its serialized dictionary.
    //
    // dict   - serialized variable. Must contain the type, uid, kind, data-type, dynamic-axis,
    //          sparsity, needs-gradient and shape entries (checked by ValidateDictionary).
    //          The name entry is optional. Constant and Parameter variables additionally
    //          need a value entry.
    // device - device onto which the stored value of a Constant/Parameter is deep-cloned.
    //
    // Returns a Parameter or Constant for those kinds, and a plain Variable for Input/Placeholder.
    // Raises LogicError for an unexpected variable kind or data type, or when a Constant/Parameter
    // has no stored value.
    /*static*/ Variable Variable::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
    {
        static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, uidKey, kindKey, dataTypeKey, dynamicAxisKey, isSparseKey, needsGradientKey, shapeKey };

        size_t version = ValidateDictionary<Variable>(dict, s_requiredDictionaryKeys, s_variableTypeValue, s_serializationVersion);

        const auto& uid = dict[uidKey].Value<std::wstring>();

        VariableKind kind = VariableKind(dict[kindKey].Value<std::size_t>());
        if (kind != VariableKind::Constant &&
            kind != VariableKind::Input &&
            kind != VariableKind::Parameter &&
            kind != VariableKind::Placeholder)
        {
            LogicError("Unexpected variable '%ls':'%u' (%s).",
                       kindKey.c_str(),
                       static_cast<std::underlying_type<VariableKind>::type>(kind),
                       GetVersionsString<Variable>(s_serializationVersion, version).c_str());
        }

        DataType dataType = DataType(dict[dataTypeKey].Value<std::size_t>());
        if (dataType != DataType::Unknown &&
            dataType != DataType::Float &&
            dataType != DataType::Double)
        {
            LogicError("Unexpected variable '%ls':'%u' (%s).",
                       dataTypeKey.c_str(),
                       static_cast<std::underlying_type<DataType>::type>(dataType),
                       GetVersionsString<Variable>(s_serializationVersion, version).c_str());
        }

        const vector<DictionaryValue>& dictionaryValueVector = dict[dynamicAxisKey].Value<vector<DictionaryValue>>();
        vector<Axis> dynamicAxis;
        dynamicAxis.reserve(dictionaryValueVector.size());
        for (const auto& dictionaryValue : dictionaryValueVector)
        {
            dynamicAxis.push_back(dictionaryValue.Value<Axis>());
        }

        bool isSparse = dict[isSparseKey].Value<bool>();
        std::wstring name = L"";
        if (dict.Contains(nameKey))
            name = dict[nameKey].Value<std::wstring>();
        bool needsGradient = dict[needsGradientKey].Value<bool>();
        const auto& shape = dict[shapeKey].Value<NDShape>();

        if (kind == VariableKind::Constant || kind == VariableKind::Parameter)
        {
            // valueKey is not part of s_requiredDictionaryKeys (Input/Placeholder variables have no
            // value), so validate it here. Otherwise a missing value would surface as an opaque
            // lookup failure.
            if (!dict.Contains(valueKey))
            {
                LogicError("Required key '%ls' is missing for variable '%ls' of kind '%u' (%s).",
                           valueKey.c_str(),
                           uid.c_str(),
                           static_cast<std::underlying_type<VariableKind>::type>(kind),
                           GetVersionsString<Variable>(s_serializationVersion, version).c_str());
            }

            const auto& value = dict[valueKey].Value<NDArrayView>();

            // TODO: this copying here is redundant, value should be moved from the dictionary to the variable.
            // Also, the correct device should be used upfront when deserializing NDArrayView.
            Variable var(shape, kind, dataType, value.DeepClone(device, kind == VariableKind::Constant), needsGradient, dynamicAxis, isSparse, name, uid);
            if (var.IsParameter())
                return Parameter(var);
            else
                return Constant(var);
        }

        return Variable(shape, kind, dataType, nullptr, needsGradient, dynamicAxis, isSparse, name, uid);
    }