// Restores the trainer from a checkpoint:
//  1. restores the model parameters of the combined training function from `modelFilePath`,
//  2. loads the trainer-state dictionary from GetTrainerStateCheckpointFilePath(modelFilePath),
//  3. restores the parameter learners from the learner state stored in it.
// Returns the external state saved in the checkpoint. In distributed mode the returned
// dictionary can be worker-specific (see the version / distributed-state handling below).
// NOTE(review): for version-0 checkpoints in non-distributed mode, the raw external-state
// dictionary is returned as-is, even though version-0 distributed checkpoints stored it keyed
// by worker id — confirm this is intended.
Dictionary Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
{
    m_combinedTrainingFunction->Restore(modelFilePath);

    Dictionary trainerState = Dictionary::Load(GetTrainerStateCheckpointFilePath(modelFilePath));

    // Checkpoints without an explicit version entry are treated as version 0.
    const size_t version = trainerState.Contains(versionPropertyName)
        ? trainerState[versionPropertyName].Value<size_t>()
        : 0;

    auto learnersState = trainerState[learnersPropertyName].Value<std::vector<DictionaryValue>>();
    auto userState = trainerState[externalStatePropertyName].Value<Dictionary>();

    m_parameterLearners->RestoreFromCheckpoint(learnersState);

    if (!m_distributed)
        return userState;

    // Wait for every worker here so that nobody starts writing the model/checkpoint
    // files before all workers have finished reading them.
    DistributedCommunicatorPtr communicator = MPICommunicator();
    communicator->Barrier();

    const auto& currentWorker = communicator->CurrentWorker();
    const std::wstring localWorkerId = std::to_wstring(currentWorker.m_globalRank);

    // Before version 1 there was no separate distributed state: the external state itself
    // was a dictionary of per-worker external states, keyed by worker rank. Fall back to
    // the main worker's (rank 0) entry when this worker has none.
    if (version == 0)
    {
        if (userState.Contains(localWorkerId))
            return userState[localWorkerId].Value<Dictionary>();
        return userState[std::to_wstring(0)].Value<Dictionary>();
    }

    Dictionary perWorkerStates = trainerState[distributedStatePropertyName].Value<Dictionary>();

    // The main worker, and any worker without its own saved entry, uses the shared external state.
    const bool hasOwnState = !currentWorker.IsMain() && perWorkerStates.Contains(localWorkerId);
    if (!hasOwnState)
        return userState;

    // The checkpoint holds internal state specific to this worker.
    Dictionary workerEntry = perWorkerStates[localWorkerId].Value<Dictionary>();
    auto workerInternalState = workerEntry[internalWorkerStateKey].Value<Dictionary>();

    auto composite = std::dynamic_pointer_cast<CompositeFunction>(m_combinedTrainingFunction);
    if (!composite)
        RuntimeError("Combined training function is not a CompositeFunction.");

    // Assumes the composite function (restored from a checkpoint made by the main worker)
    // and this worker's internal state use identical UIDs.
    composite->SetInternalState(workerInternalState);

    return workerEntry[externalWorkerStateKey].Value<Dictionary>();
}
// dictionary load/contains/find test void Test::Test9() { const size_t dict_size = 10; ofstream f("dict"); for (size_t i = 0; i < dict_size; ++i) f << i << endl; f.close(); Dictionary dict; string msg; bool res = dict.Load("dict", 1, msg); for (size_t i = 0; res && i < dict_size; ++i) { string str = to_string(i); res = dict.Contains(str); Dictionary::DictionaryData::iterator di; res = res && dict.Find(str, di); res = res && di->first.compare(str) == 0; res = res && di->second == false; } printf(res ? "Test9:\tpassed\r\n" : "Test9:\tfailed\r\n"); remove("dict"); }
// Reconstructs a user-defined function (UDF) from its serialized dictionary form.
//
// Parameters:
//   dict             - the serialized UDF. typeKey, uidKey, inputsKey and userDefinedStateKey are
//                      required (checked by ValidateDictionary); nameKey is optional.
//   uidToVariableMap - passed to GetInputVariables to obtain the UDF's input variables.
//   device           - not used by this function.
// Returns the reconstructed UDF, with its original uid restored.
// Raises RuntimeError if neither the native deserializer nor a registered SWIG callback
// produced a function.
/*static*/ FunctionPtr UDFUtils::Deserialize(const Dictionary& dict,
                                              const unordered_map<std::wstring, Variable>& uidToVariableMap,
                                              const DeviceDescriptor& device)
{
    static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, uidKey, inputsKey, userDefinedStateKey };

    ValidateDictionary<PrimitiveFunction>(dict, s_requiredDictionaryKeys, s_userDefinedFunctionTypeValue, s_serializationVersion);

    const auto& uid = dict[uidKey].Value<std::wstring>();

    std::wstring name = L"";
    if (dict.Contains(nameKey))
        name = dict[nameKey].Value<std::wstring>();

    auto inputs = GetInputVariables(dict, uidToVariableMap, s_serializationVersion);
    auto state = dict[userDefinedStateKey].Value<Dictionary>();

    FunctionPtr udf;
    if (IsNativeUDF(dict))
    {
        udf = Function::DeserializeNativeImpl(inputs, name, state);
    }
    else if (s_SWIGCallbackWrapper != nullptr)
    {
        // When called through SWIG, the actual deserializer is registered by the
        // target-language CNTK implementation (e.g. cntk_py for Python).
        udf = s_SWIGCallbackWrapper->operator()(inputs, name, state);
    }

    if (udf == nullptr)
    {
        // %ls is the ISO C conversion for wchar_t strings (consistent with the other
        // deserializers); %S is a non-standard XSI/MSVC extension.
        RuntimeError("Unable to reconstruct a user-defined function (name = %ls, uid = %ls). "
                     "Please make sure to specify a valid UDF deserializer.", name.c_str(), uid.c_str());
    }

    // Restore the original uid. Other functions in the graph depend on it: their inputs refer
    // to the uids of this UDF's outputs, which are generated based on the uid of this UDF.
    udf->m_uid = uid;
    return udf;
}
/*static*/ bool UDFUtils::IsNativeUDF(const Dictionary& dict) { assert(IsUDF(dict)); return (dict.Contains(nativeUDFKey) && dict[nativeUDFKey].Value<bool>() == true); }
// Returns true when `dict` has a typeKey entry equal to the user-defined-function type tag.
/*static*/ bool UDFUtils::IsUDF(const Dictionary& dict)
{
    if (!dict.Contains(typeKey))
        return false;

    const auto& serializedType = dict[typeKey].Value<std::wstring>();
    return serializedType == s_userDefinedFunctionTypeValue;
}
// Reconstructs a Variable from its serialized dictionary form.
//
// The dictionary must contain typeKey, uidKey, kindKey, dataTypeKey, dynamicAxisKey,
// isSparseKey, needsGradientKey and shapeKey (checked by ValidateDictionary); nameKey is
// optional. Constants and Parameters additionally read their value from valueKey, and that
// value is deep-cloned onto `device` (read-only for Constants).
// Raises LogicError for a variable kind other than Constant/Input/Parameter/Placeholder,
// or a data type other than Unknown/Float/Double.
// Returns a Parameter or Constant for valued kinds, otherwise a plain Variable.
/*static*/ Variable Variable::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
{
    static const vector<std::wstring> s_requiredDictionaryKeys =
    {
        typeKey, uidKey, kindKey, dataTypeKey, dynamicAxisKey, isSparseKey, needsGradientKey, shapeKey
    };

    const size_t version = ValidateDictionary<Variable>(dict, s_requiredDictionaryKeys, s_variableTypeValue, s_serializationVersion);

    const auto& uid = dict[uidKey].Value<std::wstring>();

    const VariableKind kind = VariableKind(dict[kindKey].Value<std::size_t>());
    switch (kind)
    {
    case VariableKind::Constant:
    case VariableKind::Input:
    case VariableKind::Parameter:
    case VariableKind::Placeholder:
        break;
    default:
        LogicError("Unexpected variable '%ls':'%u' (%s).",
                   kindKey.c_str(),
                   static_cast<std::underlying_type<VariableKind>::type>(kind),
                   GetVersionsString<Variable>(s_serializationVersion, version).c_str());
    }

    const DataType dataType = DataType(dict[dataTypeKey].Value<std::size_t>());
    const bool isKnownDataType = dataType == DataType::Unknown
                              || dataType == DataType::Float
                              || dataType == DataType::Double;
    if (!isKnownDataType)
    {
        LogicError("Unexpected variable '%ls':'%u' (%s).",
                   dataTypeKey.c_str(),
                   static_cast<std::underlying_type<DataType>::type>(dataType),
                   GetVersionsString<Variable>(s_serializationVersion, version).c_str());
    }

    const auto& serializedAxes = dict[dynamicAxisKey].Value<vector<DictionaryValue>>();
    vector<Axis> dynamicAxes;
    dynamicAxes.reserve(serializedAxes.size());
    for (const auto& axisValue : serializedAxes)
        dynamicAxes.push_back(axisValue.Value<Axis>());

    const bool isSparse = dict[isSparseKey].Value<bool>();

    const std::wstring name = dict.Contains(nameKey)
        ? dict[nameKey].Value<std::wstring>()
        : std::wstring(L"");

    const bool needsGradient = dict[needsGradientKey].Value<bool>();
    const auto& shape = dict[shapeKey].Value<NDShape>();

    const bool carriesValue = kind == VariableKind::Constant || kind == VariableKind::Parameter;
    if (!carriesValue)
        return Variable(shape, kind, dataType, nullptr, needsGradient, dynamicAxes, isSparse, name, uid);

    const auto& storedValue = dict[valueKey].Value<NDArrayView>();

    // TODO: this copy is redundant; the value should be moved from the dictionary to the
    // variable, and the correct device should be used upfront when deserializing NDArrayView.
    Variable restored(shape, kind, dataType,
                      storedValue.DeepClone(device, kind == VariableKind::Constant),
                      needsGradient, dynamicAxes, isSparse, name, uid);

    if (restored.IsParameter())
        return Parameter(restored);
    return Constant(restored);
}