Beispiel #1
0
// Puts `region` into exactly one phase, whose index is the current number
// of entries in phaseInfo_ (presumably a brand-new phase after all existing
// ones -- confirm against setPhases_).
void Network::setDefaultPhase_(Region* region)
{
  // The set is kept as a named local because it is handed to setPhases_
  // as an lvalue.
  std::set<UInt32> defaultPhase;
  defaultPhase.insert(static_cast<UInt32>(phaseInfo_.size()));
  setPhases_(region, defaultPhase);
}
Beispiel #2
0
// Assigns the given set of phases to the region called `name`.
// Throws (via NTA_THROW) when no region with that name is registered.
void Network::setPhases(const std::string& name, std::set<UInt32>& phases)
{
  const bool regionKnown = regions_.contains(name);
  if (!regionKnown)
  {
    NTA_THROW << "setPhases -- no region exists with name '" << name << "'";
  }

  // Look the region up and delegate to the private implementation.
  setPhases_(regions_.getByName(name), phases);
}
Beispiel #3
0
void Network::loadFromBundle(const std::string& name)
{
  if (! StringUtils::endsWith(name, ".nta"))
    NTA_THROW << "loadFromBundle: bundle extension must be \".nta\"";

  std::string fullPath = Path::normalize(Path::makeAbsolute(name));

  if (! Path::exists(fullPath))
    NTA_THROW << "Path " << fullPath << " does not exist";

  std::string networkStructureFilename = Path::join(fullPath, "network.yaml");
  std::ifstream f(networkStructureFilename.c_str());
  YAML::Parser parser(f);
  YAML::Node doc;
  bool success = parser.GetNextDocument(doc);
  if (!success)
    NTA_THROW << "Unable to find YAML document in network structure file " 
              << networkStructureFilename;

  if (doc.Type() != YAML::NodeType::Map)
    NTA_THROW << "Invalid network structure file -- does not contain a map";

  // Should contain Version, Regions, Links
  if (doc.size() != 3)
    NTA_THROW << "Invalid network structure file -- contains " 
              << doc.size() << " elements";

  // Extra version
  const YAML::Node *node = doc.FindValue("Version");
  if (node == NULL)
    NTA_THROW << "Invalid network structure file -- no version";
  
  int version;
  *node >> version;
  if (version != 2)
    NTA_THROW << "Invalid network structure file -- only version 2 supported";
  
  // Regions
  const YAML::Node *regions = doc.FindValue("Regions");
  if (regions == NULL)
    NTA_THROW << "Invalid network structure file -- no regions";

  if (regions->Type() != YAML::NodeType::Sequence)
    NTA_THROW << "Invalid network structure file -- regions element is not a list";
  
  for (YAML::Iterator region = regions->begin(); region != regions->end(); region++)
  {
    // Each region is a map -- extract the 5 values in the map
    if ((*region).Type() != YAML::NodeType::Map)
      NTA_THROW << "Invalid network structure file -- bad region (not a map)";
    
    if ((*region).size() != 5)
      NTA_THROW << "Invalid network structure file -- bad region (wrong size)";
    
    // 1. name
    node = (*region).FindValue("name");
    if (node == NULL)
      NTA_THROW << "Invalid network structure file -- region has no name";
    std::string name;
    *node >> name;

    // 2. nodeType
    node = (*region).FindValue("nodeType");
    if (node == NULL)
      NTA_THROW << "Invalid network structure file -- region " 
                << name << " has no node type";
    std::string nodeType;
    *node >> nodeType;

    // 3. dimensions
    node = (*region).FindValue("dimensions");
    if (node == NULL)
      NTA_THROW << "Invalid network structure file -- region "
                << name << " has no dimensions";
    if ((*node).Type() != YAML::NodeType::Sequence)
      NTA_THROW << "Invalid network structure file -- region "
                << name << " dimensions specified incorrectly";
    Dimensions dimensions;
    for (YAML::Iterator valiter = (*node).begin(); valiter != (*node).end(); valiter++)
    {
      size_t val;
      (*valiter) >> val;
      dimensions.push_back(val);
    }

    // 4. phases
    node = (*region).FindValue("phases");
    if (node == NULL)
      NTA_THROW << "Invalid network structure file -- region"
                << name << "has no phases";
    if ((*node).Type() != YAML::NodeType::Sequence)
      NTA_THROW << "Invalid network structure file -- region "
                << name << " phases specified incorrectly";

    std::set<UInt32> phases;
    for (YAML::Iterator valiter = (*node).begin(); valiter != (*node).end(); valiter++)
    {
      UInt32 val;
      (*valiter) >> val;
      phases.insert(val);
    }
    
    // 5. label
    node = (*region).FindValue("label");
    if (node == NULL)
      NTA_THROW << "Invalid network structure file -- region"
                << name << "has no label";
    std::string label;
    *node >> label;
    
    Region *r = addRegionFromBundle(name, nodeType, dimensions, fullPath, label);
    setPhases_(r, phases);


  }

  const YAML::Node *links = doc.FindValue("Links");
  if (links == NULL)
    NTA_THROW << "Invalid network structure file -- no links";

  if (links->Type() != YAML::NodeType::Sequence)
    NTA_THROW << "Invalid network structure file -- links element is not a list";

  for (YAML::Iterator link = links->begin(); link != links->end(); link++)
  {
    // Each link is a map -- extract the 5 values in the map
    if ((*link).Type() != YAML::NodeType::Map)
      NTA_THROW << "Invalid network structure file -- bad link (not a map)";
    
    if ((*link).size() != 6)
      NTA_THROW << "Invalid network structure file -- bad link (wrong size)";
    
    // 1. type
    node = (*link).FindValue("type");
    if (node == NULL)
      NTA_THROW << "Invalid network structure file -- link does not have a type";
    std::string linkType;
    *node >> linkType;

    // 2. params
    node = (*link).FindValue("params");
    if (node == NULL)
      NTA_THROW << "Invalid network structure file -- link does not have params";
    std::string params;
    *node >> params;

    // 3. srcRegion (name)
    node = (*link).FindValue("srcRegion");
    if (node == NULL)
      NTA_THROW << "Invalid network structure file -- link does not have a srcRegion";
    std::string srcRegionName;
    *node >> srcRegionName;


    // 4. srcOutput
    node = (*link).FindValue("srcOutput");
    if (node == NULL)
      NTA_THROW << "Invalid network structure file -- link does not have a srcOutput";
    std::string srcOutputName;
    *node >> srcOutputName;

    // 5. destRegion
    node = (*link).FindValue("destRegion");
    if (node == NULL)
      NTA_THROW << "Invalid network structure file -- link does not have a destRegion";
    std::string destRegionName;
    *node >> destRegionName;

    // 6. destInput
    node = (*link).FindValue("destInput");
    if (node == NULL)
      NTA_THROW << "Invalid network structure file -- link does not have a destInput";
    std::string destInputName;
    *node >> destInputName;

    if (!regions_.contains(srcRegionName))
      NTA_THROW << "Invalid network structure file -- link specifies source region '" << srcRegionName << "' but no such region exists";
    Region* srcRegion = regions_.getByName(srcRegionName);

    if (!regions_.contains(destRegionName))
      NTA_THROW << "Invalid network structure file -- link specifies destination region '" << destRegionName << "' but no such region exists";
    Region* destRegion = regions_.getByName(destRegionName);

    Output* srcOutput = srcRegion->getOutput(srcOutputName);
    if (srcOutput == NULL)
      NTA_THROW << "Invalid network structure file -- link specifies source output '" << srcOutputName << "' but no such name exists";

    Input* destInput = destRegion->getInput(destInputName);
    if (destInput == NULL)
      NTA_THROW << "Invalid network structure file -- link specifies destination input '" << destInputName << "' but no such name exists";

    // Create the link itself
    destInput->addLink(linkType, params, srcOutput);


  } // links

}
Beispiel #4
0
// Replaces the contents of this network with what is described by `proto`:
// existing regions are deleted, regions from the proto are added with their
// phases, and each link in the proto is attached to its destination input.
// A link naming an unknown region, output or input is reported via NTA_THROW.
// The network is left marked as not initialized.
void Network::read(NetworkProto::Reader& proto)
{
  // Discard existing regions one at a time, always taking the entry at
  // index 0: free the Region object, then drop its entry by name.
  while (regions_.getCount() != 0)
  {
    auto front = regions_.getByIndex(0);
    delete front.second;
    regions_.remove(front.first);
  }

  // Recreate each serialized region and give it its phases.
  for (auto regionEntry : proto.getRegions().getEntries())
  {
    auto serializedRegion = regionEntry.getValue();
    auto newRegion =
        addRegionFromProto(regionEntry.getKey().cStr(), serializedRegion);

    std::set<UInt32> regionPhases;
    for (auto phaseIndex : serializedRegion.getPhases())
      regionPhases.insert(phaseIndex);

    setPhases_(newRegion, regionPhases);
  }

  // Recreate links. Note that we can't just pass the capnp struct to
  // Link.read because the linked input and output need references to the
  // new link.
  for (auto serializedLink : proto.getLinks())
  {
    // Source side: region first, then the named output on it.
    const char* const srcRegionName = serializedLink.getSrcRegion().cStr();
    if (!regions_.contains(srcRegionName))
      NTA_THROW << "Link references unknown region: " << srcRegionName;
    Region* const source = regions_.getByName(srcRegionName);

    const char* const srcOutputName = serializedLink.getSrcOutput().cStr();
    Output* const sourceOutput = source->getOutput(srcOutputName);
    if (!sourceOutput)
      NTA_THROW << "Link references unknown source output: " << srcOutputName;

    // Destination side: region first, then the named input on it.
    const char* const destRegionName = serializedLink.getDestRegion().cStr();
    if (!regions_.contains(destRegionName))
      NTA_THROW << "Link references unknown region: " << destRegionName;
    Region* const destination = regions_.getByName(destRegionName);

    const char* const destInputName = serializedLink.getDestInput().cStr();
    Input* const destinationInput = destination->getInput(destInputName);
    if (!destinationInput)
      NTA_THROW << "Link references unknown destination input: "
                << destInputName;

    // Actually create the link
    const char* const linkType = serializedLink.getType().cStr();
    const char* const linkParams = serializedLink.getParams().cStr();
    destinationInput->addLink(linkType, linkParams, sourceOutput);
  }

  initialized_ = false;
}