/**
 * Assert that each entry stored locally in @p mat equals @p expected.
 *
 * Only the entries held by the calling process (its LocalHeight() x
 * LocalWidth() block) are inspected, row by row. The first mismatch triggers
 * a fatal ASSERT_EQ, which ends this function.
 *
 * @param mat       Matrix whose local entries are checked.
 * @param expected  Value every local entry must equal.
 */
void validate_mat(DistMat& mat, DataType expected)
{
  const auto n_rows = mat.LocalHeight();
  const auto n_cols = mat.LocalWidth();
  for (int row = 0; row < n_rows; ++row)
  {
    for (int col = 0; col < n_cols; ++col)
    {
      ASSERT_EQ(mat.GetLocal(row, col), expected);
    }
  }
}
static bool writeDist(int fd, const char* filename, const DistMat& M, uint64_t* bytes) { struct layer_header header; header.rank = (uint64_t) M.Grid().Rank(); header.width = (uint64_t) M.Width(); header.height = (uint64_t) M.Height(); header.localwidth = (uint64_t) M.LocalWidth(); header.localheight = (uint64_t) M.LocalHeight(); header.ldim = (uint64_t) M.LDim(); ssize_t write_rc = write(fd, &header, sizeof(header)); if (write_rc != sizeof(header)) { // error! } *bytes += write_rc; const Int localHeight = M.LocalHeight(); const Int localWidth = M.LocalWidth(); const Int lDim = M.LDim(); if(localHeight == lDim) { void* buf = (void*) M.LockedBuffer(); size_t bufsize = localHeight * localWidth * sizeof(DataType); write_rc = write(fd, buf, bufsize); if (write_rc != bufsize) { // error! } *bytes += write_rc; } else { for(Int j = 0; j < localWidth; ++j) { void* buf = (void*) M.LockedBuffer(0, j); size_t bufsize = localHeight * sizeof(DataType); write_rc = write(fd, buf, bufsize); if (write_rc != bufsize) { // error! } *bytes += write_rc; } } return true; }
static bool readDist(int fd, const char* filename, DistMat& M, uint64_t* bytes) { struct layer_header header; ssize_t read_rc = read(fd, &header, sizeof(header)); if (read_rc != sizeof(header)) { // error! } *bytes += read_rc; // check that header values match up Int height = header.height; Int width = header.width; M.Resize(height, width); if(M.ColStride() == 1 && M.RowStride() == 1) { if(M.Height() == M.LDim()) { void* buf = (void*) M.Buffer(); size_t bufsize = height * width * sizeof(DataType); read_rc = read(fd, buf, bufsize); if (read_rc != bufsize) { // error! } *bytes += read_rc; } else { for(Int j = 0; j < width; ++j) { void* buf = (void*) M.Buffer(0, j); size_t bufsize = height * sizeof(DataType); read_rc = read(fd, buf, bufsize); if (read_rc != bufsize) { // error! } *bytes += read_rc; } } } else { const Int localHeight = M.LocalHeight(); const Int localWidth = M.LocalWidth(); const Int lDim = M.LDim(); if(localHeight == lDim) { void* buf = (void*) M.Buffer(); size_t bufsize = localHeight * localWidth * sizeof(DataType); read_rc = read(fd, buf, bufsize); if (read_rc != bufsize) { // error! } *bytes += read_rc; } else { for(Int jLoc = 0; jLoc < localWidth; ++jLoc) { void* buf = (void*) M.Buffer(0, jLoc); size_t bufsize = localHeight * sizeof(DataType); read_rc = read(fd, buf, bufsize); if (read_rc != bufsize) { // error! } *bytes += read_rc; } } } return true; }