// Place `region` alone in a brand-new phase whose index is the current
// phaseInfo_.size() (presumably one past the highest existing phase index).
void Network::setDefaultPhase_(Region* region)
{
  const UInt32 nextPhase = phaseInfo_.size();
  std::set<UInt32> singlePhase{nextPhase};
  setPhases_(region, singlePhase);
}
// Public entry point for assigning phases to a region identified by name.
// Throws (via NTA_THROW) when no region with that name is registered;
// otherwise forwards to setPhases_ with the resolved Region pointer.
void Network::setPhases(const std::string& name, std::set<UInt32>& phases)
{
  if (!regions_.contains(name))
  {
    NTA_THROW << "setPhases -- no region exists with name '" << name << "'";
  }

  setPhases_(regions_.getByName(name), phases);
}
void Network::loadFromBundle(const std::string& name) { if (! StringUtils::endsWith(name, ".nta")) NTA_THROW << "loadFromBundle: bundle extension must be \".nta\""; std::string fullPath = Path::normalize(Path::makeAbsolute(name)); if (! Path::exists(fullPath)) NTA_THROW << "Path " << fullPath << " does not exist"; std::string networkStructureFilename = Path::join(fullPath, "network.yaml"); std::ifstream f(networkStructureFilename.c_str()); YAML::Parser parser(f); YAML::Node doc; bool success = parser.GetNextDocument(doc); if (!success) NTA_THROW << "Unable to find YAML document in network structure file " << networkStructureFilename; if (doc.Type() != YAML::NodeType::Map) NTA_THROW << "Invalid network structure file -- does not contain a map"; // Should contain Version, Regions, Links if (doc.size() != 3) NTA_THROW << "Invalid network structure file -- contains " << doc.size() << " elements"; // Extra version const YAML::Node *node = doc.FindValue("Version"); if (node == NULL) NTA_THROW << "Invalid network structure file -- no version"; int version; *node >> version; if (version != 2) NTA_THROW << "Invalid network structure file -- only version 2 supported"; // Regions const YAML::Node *regions = doc.FindValue("Regions"); if (regions == NULL) NTA_THROW << "Invalid network structure file -- no regions"; if (regions->Type() != YAML::NodeType::Sequence) NTA_THROW << "Invalid network structure file -- regions element is not a list"; for (YAML::Iterator region = regions->begin(); region != regions->end(); region++) { // Each region is a map -- extract the 5 values in the map if ((*region).Type() != YAML::NodeType::Map) NTA_THROW << "Invalid network structure file -- bad region (not a map)"; if ((*region).size() != 5) NTA_THROW << "Invalid network structure file -- bad region (wrong size)"; // 1. name node = (*region).FindValue("name"); if (node == NULL) NTA_THROW << "Invalid network structure file -- region has no name"; std::string name; *node >> name; // 2. 
nodeType node = (*region).FindValue("nodeType"); if (node == NULL) NTA_THROW << "Invalid network structure file -- region " << name << " has no node type"; std::string nodeType; *node >> nodeType; // 3. dimensions node = (*region).FindValue("dimensions"); if (node == NULL) NTA_THROW << "Invalid network structure file -- region " << name << " has no dimensions"; if ((*node).Type() != YAML::NodeType::Sequence) NTA_THROW << "Invalid network structure file -- region " << name << " dimensions specified incorrectly"; Dimensions dimensions; for (YAML::Iterator valiter = (*node).begin(); valiter != (*node).end(); valiter++) { size_t val; (*valiter) >> val; dimensions.push_back(val); } // 4. phases node = (*region).FindValue("phases"); if (node == NULL) NTA_THROW << "Invalid network structure file -- region" << name << "has no phases"; if ((*node).Type() != YAML::NodeType::Sequence) NTA_THROW << "Invalid network structure file -- region " << name << " phases specified incorrectly"; std::set<UInt32> phases; for (YAML::Iterator valiter = (*node).begin(); valiter != (*node).end(); valiter++) { UInt32 val; (*valiter) >> val; phases.insert(val); } // 5. 
label node = (*region).FindValue("label"); if (node == NULL) NTA_THROW << "Invalid network structure file -- region" << name << "has no label"; std::string label; *node >> label; Region *r = addRegionFromBundle(name, nodeType, dimensions, fullPath, label); setPhases_(r, phases); } const YAML::Node *links = doc.FindValue("Links"); if (links == NULL) NTA_THROW << "Invalid network structure file -- no links"; if (links->Type() != YAML::NodeType::Sequence) NTA_THROW << "Invalid network structure file -- links element is not a list"; for (YAML::Iterator link = links->begin(); link != links->end(); link++) { // Each link is a map -- extract the 5 values in the map if ((*link).Type() != YAML::NodeType::Map) NTA_THROW << "Invalid network structure file -- bad link (not a map)"; if ((*link).size() != 6) NTA_THROW << "Invalid network structure file -- bad link (wrong size)"; // 1. type node = (*link).FindValue("type"); if (node == NULL) NTA_THROW << "Invalid network structure file -- link does not have a type"; std::string linkType; *node >> linkType; // 2. params node = (*link).FindValue("params"); if (node == NULL) NTA_THROW << "Invalid network structure file -- link does not have params"; std::string params; *node >> params; // 3. srcRegion (name) node = (*link).FindValue("srcRegion"); if (node == NULL) NTA_THROW << "Invalid network structure file -- link does not have a srcRegion"; std::string srcRegionName; *node >> srcRegionName; // 4. srcOutput node = (*link).FindValue("srcOutput"); if (node == NULL) NTA_THROW << "Invalid network structure file -- link does not have a srcOutput"; std::string srcOutputName; *node >> srcOutputName; // 5. destRegion node = (*link).FindValue("destRegion"); if (node == NULL) NTA_THROW << "Invalid network structure file -- link does not have a destRegion"; std::string destRegionName; *node >> destRegionName; // 6. 
destInput node = (*link).FindValue("destInput"); if (node == NULL) NTA_THROW << "Invalid network structure file -- link does not have a destInput"; std::string destInputName; *node >> destInputName; if (!regions_.contains(srcRegionName)) NTA_THROW << "Invalid network structure file -- link specifies source region '" << srcRegionName << "' but no such region exists"; Region* srcRegion = regions_.getByName(srcRegionName); if (!regions_.contains(destRegionName)) NTA_THROW << "Invalid network structure file -- link specifies destination region '" << destRegionName << "' but no such region exists"; Region* destRegion = regions_.getByName(destRegionName); Output* srcOutput = srcRegion->getOutput(srcOutputName); if (srcOutput == NULL) NTA_THROW << "Invalid network structure file -- link specifies source output '" << srcOutputName << "' but no such name exists"; Input* destInput = destRegion->getInput(destInputName); if (destInput == NULL) NTA_THROW << "Invalid network structure file -- link specifies destination input '" << destInputName << "' but no such name exists"; // Create the link itself destInput->addLink(linkType, params, srcOutput); } // links }
/**
 * Rebuild this network's regions and links from a Cap'n Proto NetworkProto.
 *
 * Steps:
 *  1. Delete every region currently held in regions_.
 *  2. Recreate each serialized region with addRegionFromProto() and assign
 *     the phases stored in its proto.
 *  3. Recreate each link, once all regions exist, by resolving its source
 *     output and destination input by name and calling Input::addLink().
 *     Link.read cannot simply be handed the capnp struct, because the
 *     linked input and output need references to the new link.
 * On return, initialized_ is false.
 *
 * @param proto  Reader for the serialized network.
 * @throws (via NTA_THROW) if a link names a region, output or input that does
 *         not exist after the regions have been restored.
 */
void Network::read(NetworkProto::Reader& proto)
{
  // Step 1: discard existing regions. Index 0 is always taken because the
  // collection shrinks by one entry per iteration.
  // NOTE(review): only regions_ is cleared here; phaseInfo_ may still refer
  // to the deleted regions -- confirm whether it needs resetting too.
  while (regions_.getCount() > 0)
  {
    auto victim = regions_.getByIndex(0);
    delete victim.second;
    regions_.remove(victim.first);
  }

  // Step 2: restore regions together with their phase assignments.
  auto regionEntries = proto.getRegions().getEntries();
  for (auto regionEntry : regionEntries)
  {
    auto regionProto = regionEntry.getValue();
    auto restored = addRegionFromProto(regionEntry.getKey().cStr(), regionProto);

    std::set<UInt32> restoredPhases;
    for (auto phaseId : regionProto.getPhases())
    {
      restoredPhases.insert(phaseId);
    }
    setPhases_(restored, restoredPhases);
  }

  // Step 3: restore links, validating each endpoint before use.
  auto linkList = proto.getLinks();
  for (auto link : linkList)
  {
    const char* srcRegionName = link.getSrcRegion().cStr();
    const char* srcOutputName = link.getSrcOutput().cStr();
    const char* destRegionName = link.getDestRegion().cStr();
    const char* destInputName = link.getDestInput().cStr();

    if (!regions_.contains(srcRegionName))
    {
      NTA_THROW << "Link references unknown region: " << srcRegionName;
    }
    Region* fromRegion = regions_.getByName(srcRegionName);

    Output* fromOutput = fromRegion->getOutput(srcOutputName);
    if (fromOutput == nullptr)
    {
      NTA_THROW << "Link references unknown source output: " << srcOutputName;
    }

    if (!regions_.contains(destRegionName))
    {
      NTA_THROW << "Link references unknown region: " << destRegionName;
    }
    Region* toRegion = regions_.getByName(destRegionName);

    Input* toInput = toRegion->getInput(destInputName);
    if (toInput == nullptr)
    {
      NTA_THROW << "Link references unknown destination input: " << destInputName;
    }

    toInput->addLink(link.getType().cStr(), link.getParams().cStr(), fromOutput);
  }

  initialized_ = false;
}