typedef vector<Blob*> B; TEST_CASE("TestGraph", "[Graph]") { Runnable test_graph; Op<Conv>* o = test_graph.create<Conv>("conv", "main", Conv::param_tuple(2, 2, 1, 1)); Blob* bottom = test_graph.create("bottom", 0, 0, {1, 3, 10, 10}); Blob* top = test_graph.create("top", 0, 0, {1, 3, 10, 10}); Blob* weight = test_graph.create("weight", 0, 0, {3, 3, 5, 5}); B{ bottom, weight } >> (*o) >> B{ top }; SECTION("graph node number test") { vector<Node*> nodes = test_graph.nodes(); int node_size = test_graph.nodes().size(); int source_size = test_graph.sources().size(); int sink_size = test_graph.sinks().size(); REQUIRE(node_size == 4); REQUIRE(source_size == 2); REQUIRE(sink_size == 1); } SECTION("graph sources and sinks test") { vector<Node*>&& nodes = test_graph.nodes(); // test sources REQUIRE(nodes[0]->is_source() == false); REQUIRE(nodes[1]->is_source() == true); REQUIRE(nodes[2]->is_source() == false); REQUIRE(nodes[3]->is_source() == true); // test sinks REQUIRE(nodes[0]->is_sink() == false); REQUIRE(nodes[1]->is_sink() == false); REQUIRE(nodes[2]->is_sink() == true);