void MalisLossLayer<Dtype>::Malis(const Dtype* conn_data,
                                  const int_tp conn_num_dims,
                                  const int_tp* conn_dims,
                                  const int_tp* nhood_data,
                                  const int_tp* nhood_dims,
                                  const Dtype* seg_data, const bool pos,
                                  Dtype* dloss_data, Dtype* loss_out,
                                  Dtype *classerr_out, Dtype *rand_index_out) {
  // One pass (positive or negative) of the MALIS loss over an affinity graph.
  //
  // In-bounds edges are sorted with MalisAffinityGraphCompare (defined
  // elsewhere; for a maximin spanning tree this is expected to order by
  // decreasing affinity -- confirm) and merged Kruskal-style with a
  // disjoint-set forest. When an edge joins two separate components, every
  // pair of labeled voxels it newly connects is charged to that edge:
  //   pos == true : pairs with the SAME label, target affinity 1;
  //   pos == false: pairs with DIFFERENT labels, target affinity 0.
  //
  // conn_data       Affinities, linearly indexed as (d, z, y, x) with extents
  //                 conn_dims[0..3]; d selects a row of nhood_data.
  // conn_num_dims   Number of entries in conn_dims. The edge enumeration below
  //                 reads conn_dims[1..3] and nhood columns 0..2 directly, so
  //                 it assumes conn_num_dims == 4 (a 3-D volume).
  // nhood_data      nhood_dims[0] x nhood_dims[1] table of voxel offsets, one
  //                 row per edge type, columns in the order of conn_dims[1..].
  // seg_data        One label per voxel; 0 means unlabeled and is ignored.
  //                 Labels are truncated to int64_t.
  // dloss_data      Accumulated into (+=) for each spanning-tree edge and then
  //                 divided by the pair count; presumably zero-filled by the
  //                 caller -- TODO confirm.
  // loss_out        Squared error summed over charged pairs / pair count.
  // classerr_out    Fraction of charged pairs whose edge is on the wrong side
  //                 of 0.5 (0 when there are no pairs of the requested kind).
  // rand_index_out  1 - classerr.
  if ((nhood_dims[1] != (conn_num_dims - 1))
      || (nhood_dims[0] != conn_dims[0])) {
    // Print exactly the quantities the condition compares.
    LOG(FATAL) << "nhood and conn dimensions don't match"
               << " (" << nhood_dims[1] << " vs. " << (conn_num_dims - 1)
               << " and " << nhood_dims[0] << " vs. " << conn_dims[0] << ")";
  }

  // Number of voxels: product of the spatial extents conn_dims[1..].
  int64_t nVert = 1;
  for (int64_t i = 1; i < conn_num_dims; ++i) {
    nVert *= conn_dims[i];
  }

  // Row-major strides of the spatial volume; the last axis has stride 1.
  std::vector<int64_t> prodDims(conn_num_dims - 1);
  prodDims[conn_num_dims - 2] = 1;
  for (int64_t i = 1; i < conn_num_dims - 1; ++i) {
    prodDims[conn_num_dims - 2 - i] =
        prodDims[conn_num_dims - 1 - i] * conn_dims[conn_num_dims - i];
  }

  // Linear voxel offset of each neighborhood row. Kept in int64_t so that
  // offsets in large volumes are not truncated.
  std::vector<int64_t> nHood(nhood_dims[0], 0);
  for (int64_t i = 0; i < nhood_dims[0]; ++i) {
    for (int64_t j = 0; j < nhood_dims[1]; ++j) {
      nHood[i] += static_cast<int64_t>(nhood_data[j + i * nhood_dims[1]])
                  * prodDims[j];
    }
  }

  // Disjoint-set forest plus, per root, a sparse histogram
  // (label -> voxel count) of the labeled voxels in that component.
  std::vector<std::map<int64_t, int64_t> > overlap(nVert);
  std::vector<int64_t> rank(nVert);
  std::vector<int64_t> parent(nVert);
  boost::disjoint_sets<int64_t*, int64_t*> dsets(&rank[0], &parent[0]);

  std::map<int64_t, int64_t> segSizes;
  int64_t nLabeledVert = 0;
  int64_t nPairPos = 0;
  for (int64_t i = 0; i < nVert; ++i) {
    dsets.make_set(i);
    if (0 != seg_data[i]) {
      const int64_t label = static_cast<int64_t>(seg_data[i]);
      overlap[i][label] = 1;
      ++nLabeledVert;
      // A new voxel forms a same-label pair with each earlier voxel of its
      // label.
      int64_t& segSize = segSizes[label];
      nPairPos += segSize;
      ++segSize;
    }
  }

  const int64_t nPairTot = (nLabeledVert * (nLabeledVert - 1)) / 2;
  const int64_t nPairNeg = nPairTot - nPairPos;
  const int64_t nPairNorm = pos ? nPairPos : nPairNeg;

  // Collect the linear index of every edge whose neighbor voxel (offset by
  // row d of nhood_data) lies inside the volume, in (d, z, y, x) order.
  std::vector<int64_t> pqueue;
  pqueue.reserve(conn_dims[0] * nVert);
  for (int64_t d = 0, i = 0; d < conn_dims[0]; ++d) {
    const int_tp* offset = nhood_data + d * nhood_dims[1];
    for (int64_t z = 0; z < conn_dims[1]; ++z) {
      const int64_t nz = z + offset[0];
      const bool zInside = (nz >= 0) && (nz < conn_dims[1]);
      for (int64_t y = 0; y < conn_dims[2]; ++y) {
        const int64_t ny = y + offset[1];
        const bool yInside = (ny >= 0) && (ny < conn_dims[2]);
        for (int64_t x = 0; x < conn_dims[3]; ++x, ++i) {
          const int64_t nx = x + offset[2];
          if (zInside && yInside && (nx >= 0) && (nx < conn_dims[3])) {
            pqueue.push_back(i);
          }
        }
      }
    }
  }
  std::sort(pqueue.begin(), pqueue.end(),
            MalisAffinityGraphCompare<Dtype>(conn_data));

  // Kruskal's algorithm over the sorted edges.
  double loss = 0;
  int64_t nPairIncorrect = 0;
  for (std::vector<int64_t>::const_iterator edge = pqueue.begin();
       edge != pqueue.end(); ++edge) {
    const int64_t minEdge = *edge;
    const int64_t e = minEdge / nVert;   // edge type (row of nhood_data)
    const int64_t v1 = minEdge % nVert;  // voxel at the edge's origin
    const int64_t v2 = v1 + nHood[e];    // neighbor voxel along edge type e
    int64_t set1 = dsets.find_set(v1);
    int64_t set2 = dsets.find_set(v2);
    if (set1 == set2) {
      continue;  // already connected: not a spanning-tree edge
    }
    dsets.link(set1, set2);

    // Charge every newly connected labeled pair of the requested kind to
    // this edge. dl is the residual (target - affinity); dl * nPair is
    // -1/2 of the derivative of the squared term w.r.t. the affinity.
    for (std::map<int64_t, int64_t>::const_iterator it1 =
             overlap[set1].begin(); it1 != overlap[set1].end(); ++it1) {
      for (std::map<int64_t, int64_t>::const_iterator it2 =
               overlap[set2].begin(); it2 != overlap[set2].end(); ++it2) {
        const int64_t nPair = it1->second * it2->second;
        double dl;
        if (pos && (it1->first == it2->first)) {
          dl = (Dtype(1.0) - conn_data[minEdge]);
          if (conn_data[minEdge] <= Dtype(0.5)) {
            nPairIncorrect += nPair;  // same label but judged "cut"
          }
        } else if ((!pos) && (it1->first != it2->first)) {
          dl = (-conn_data[minEdge]);
          if (conn_data[minEdge] > Dtype(0.5)) {
            nPairIncorrect += nPair;  // different labels but judged "joined"
          }
        } else {
          continue;  // pair is not of the kind this pass scores
        }
        loss += dl * dl * nPair;  // squared-error loss
        dloss_data[minEdge] += dl * nPair;
      }
    }
    if (nPairNorm > 0) {
      dloss_data[minEdge] /= nPairNorm;
    } else {
      dloss_data[minEdge] = 0;
    }

    // Fold the absorbed component's histogram into the surviving root's.
    if (dsets.find_set(set1) == set2) {
      std::swap(set1, set2);
    }
    std::map<int64_t, int64_t>& merged = overlap[set1];
    for (std::map<int64_t, int64_t>::const_iterator it =
             overlap[set2].begin(); it != overlap[set2].end(); ++it) {
      merged[it->first] += it->second;
    }
    overlap[set2].clear();
  }

  // Normalize. With no pairs of the requested kind, every rate would be 0/0
  // (NaN), so report zero loss and zero classification error instead.
  double classerr = 0;
  if (nPairNorm > 0) {
    loss /= nPairNorm;
    classerr = static_cast<double>(nPairIncorrect)
               / static_cast<double>(nPairNorm);
  } else {
    loss = 0;
  }
  *loss_out = loss;
  *classerr_out = classerr;
  *rand_index_out = 1.0 - classerr;
}
TEST(Graph_Test, Standard) {
  {
    // Undirected graph on 6 vertices:
    //   0--1  2
    //   |  |_/|
    //   3--4  5
    // Labels are 0..5 and vN == 6 is the vertex count passed to g(vN). The
    // previous `enum { vN=6, v0, ... }` made v0..v5 equal 7..12, which do
    // not fit a 6-vertex graph (presumably misc::keyInt0 expects zero-based
    // ids -- confirm).
    enum { v0, v1, v2, v3, v4, v5, vN };
    typedef std::pair<int, int> Edge_type;
    std::vector<Edge_type> eVec;
    eVec.emplace_back(v0, v1);
    eVec.emplace_back(v2, v4);
    eVec.emplace_back(v2, v5);
    eVec.emplace_back(v0, v3);
    eVec.emplace_back(v1, v4);
    eVec.emplace_back(v4, v3);

    // Any of the listed traversal orders is accepted.
    const std::vector<int> bfsexp1{v0, v1, v3, v4, v2, v5};
    const std::vector<int> bfsexp2{v0, v3, v1, v4, v2, v5};
    const std::vector<int> dfsexp1{v0, v1, v4, v2, v5, v3};
    const std::vector<int> dfsexp2{v0, v1, v4, v3, v2, v5};
    const std::vector<int> dfsexp3{v0, v3, v4, v2, v5, v1};
    const std::vector<int> dfsexp4{v0, v3, v4, v1, v2, v5};

    std::vector<int> buff;
    auto dummyfun = [](int) {};
    auto printfun = [&](int v) { buff.push_back(v); };

    misc::Graph<int, misc::undirected, misc::keyInt0> g(vN);
    for (const Edge_type& edge : eVec) {
      g.addEdge(edge.first, edge.second);
    }

    g.BFS(printfun);
    EXPECT_TRUE(bfsexp1 == buff || bfsexp2 == buff);
    buff.clear();

    g.DFS(printfun, dummyfun);
    EXPECT_TRUE(dfsexp1 == buff || dfsexp2 == buff || dfsexp3 == buff ||
                dfsexp4 == buff);
    buff.clear();
    g.clear();
  }
  //////////////////////////////////////////////////////////////////////////
  {
    // Case from CLRS Figure 22.9 ('o' plays the role of vertex 0): 4 SCCs.
    misc::Graph<char, misc::directed> g;
    char de[28] = {'o','1','1','2','2','3','3','2','4','o','1','4','1','5',
                   '2','6','3','7','4','5','5','6','6','5','6','7','7','7'};
    for (int i = 0; i < 28; i += 2) {
      g.addEdge(de[i], de[i + 1]);
    }
    EXPECT_EQ(g.SCC(), 4);
    g.clear();
  }
  //////////////////////////////////////////////////////////////////////////
  {
    // Case from CLRS Figure 23.4: minimum spanning tree of weight 37.
    misc::Graph<char, misc::undirected> g;
    char e[14 * 2] = {'a','b','a','h','b','h','b','c','c','i','c','f','c','d',
                      'd','e','d','f','e','f','f','g','g','i','g','h','h','i'};
    float w[14] = {4, 8, 11, 8, 2, 4, 7, 9, 14, 10, 2, 6, 1, 7};
    for (int i = 0; i < 14; i++) {
      g.addEdge(e[2 * i], e[2 * i + 1], w[i]);
    }
    auto printfun = [](char v, char u) {
      std::cout << "(" << v << "," << u << ") ";
    };
    EXPECT_EQ(g.Kruskal_MST(printfun), 37.f);
    //EXPECT_EQ(g.Prim_MST(printfun), 37.f);
    printf("\n");
    g.clear();
  }
  //////////////////////////////////////////////////////////////////////////
  {
    // Case from CLRS Figure 24.5: a DAG with some negative edge weights.
    //   r->s 5, r->t 3, s->t 2, s->x 6, t->x 7, t->y 4, t->z 2,
    //   x->y -1, x->z 1, y->z -2
    misc::Graph<char, misc::directed> g;
    char de[2 * 10] = {'t','x','t','y','t','z', 'x','y','x','z', 'y','z',
                       'r','s','r','t','s','t','s','x',};
    float d[10] = {7, 4, 2, -1, 1, -2, 5, 3, 2, 6};
    for (int i = 0; i < 10; i++) {
      g.addEdge(de[2 * i], de[2 * i + 1], 1, d[i]);
    }
    const std::vector<char> tpexp{'r', 's', 't', 'x', 'y', 'z'};
    std::vector<char> buff;
    auto printfun = [&](char v) { buff.push_back(v); };
    g.topological_sort(printfun);
    EXPECT_TRUE(tpexp == buff);
    buff.clear();
    EXPECT_EQ(g.DAGShortestPath('s', 'z'), 3);
    g.clear();

    // Case from CLRS Figure 24.6 (Dijkstra): non-negative weights, with
    // cycles.
    char de1[2 * 10] = {'s','t','s','y','t','y','y','t', 't','x','y','z',
                        'y','x', 'x','z','z','x', 'z','s'};
    float d1[10] = {10, 5, 2, 3, 1, 2, 9, 4, 6, 7};
    for (int i = 0; i < 10; i++) {
      g.addEdge(de1[2 * i], de1[2 * i + 1], 1, d1[i]);
    }
    EXPECT_EQ(g.Dijkstra('s', 'x'), 9);
    g.clear();

    // Case from CLRS Figure 24.4 (Bellman-Ford): cycles and negative edge
    // weights.
    char de2[2 * 10] = {'s','t','s','y','t','y', 't','x','x','t','t','z',
                        'y','z','y','x', 'z','x','z','s'};
    float d2[10] = {6, 7, 8, 5, -2, -4, 9, -3, 7, 2};
    for (int i = 0; i < 10; i++) {
      g.addEdge(de2[2 * i], de2[2 * i + 1], 1, d2[i]);
    }
    EXPECT_EQ(g.Bellman_Ford('s', 'z'), -2);
    g.clear();
  }
  //////////////////////////////////////////////////////////////////////////
  {
    // Case from CLRS Figure 25.4: all-pairs shortest paths.
    misc::Graph<char, misc::directed, misc::keyMap, float, int> g;
    char de[2 * 9] = {'1','2','1','3','1','5', '2','4','2','5', '3','2',
                      '4','1','4','3', '5','4'};
    int d[9] = {3, 8, -4, 1, 7, 4, 2, -5, 6};
    for (int i = 0; i < 9; i++) {
      g.addEdge(de[2 * i], de[2 * i + 1], 1, d[i]);
    }
    int expd[5][5] = {
      { 0,  1, -3,  2, -4},
      { 3,  0, -4,  1, -1},
      { 7,  4,  0,  5,  3},
      { 2, -1, -5,  0, -2},
      { 8,  5,  1,  6,  0}
    };
    std::vector<std::vector<int>> D;
    g.Floyd_Warshall(D);
    for (char i = '1'; i <= '5'; i++) {
      for (char j = '1'; j <= '5'; j++) {
        EXPECT_EQ(D[g.Id(i)][g.Id(j)], expd[i - '1'][j - '1']);
      }
    }
  }
  //////////////////////////////////////////////////////////////////////////
  {
    // Case from CLRS Figure 26.6: maximum flow of 23.
    misc::Graph<char, misc::directed, misc::keyMap, int> g;
    char de[2 * 9] = {'s','1','s','2','2','1', '1','3','2','4', '3','2',
                      '4','3','4','t','3','t'};
    int w[9] = {16, 13, 4, 12, 14, 9, 7, 4, 20};
    for (int i = 0; i < 9; i++) {
      g.addEdge(de[2 * i], de[2 * i + 1], w[i]);
    }
    EXPECT_EQ(g.Edmonds_Karp('s', 't'), 23);
  }
  //////////////////////////////////////////////////////////////////////////
  {
    // Case from http://algs4.cs.princeton.edu/15uf/UF.java.html:
    // 10 elements, 8 unions, 2 components remain.
    char ds[10] = {'o','1','2','3','4','5','6','7','8','9'};
    char dc[8 * 2] = {'4','3', '3','8', '6','5', '9','4',
                      '2','1', '5','o', '7','2', '6','1'};
    misc::DisjointSets<char> dsets(ds, ds + 10);
    for (int i = 0; i < 8; i++) {
      dsets.link(dc[2 * i], dc[2 * i + 1]);
    }
    EXPECT_EQ(dsets.getCount(), 2);
  }
}