Example #1
0
/**
 * Check that each entry stored locally in `mat` equals `expected`.
 *
 * Rows are visited in the outer loop and columns in the inner loop. Assumes
 * ASSERT_EQ is googletest's fatal assertion, which returns from this function
 * at the first mismatch.
 */
void validate_mat(DistMat& mat, DataType expected) {
  const auto local_rows = mat.LocalHeight();
  const auto local_cols = mat.LocalWidth();
  for (int row = 0; row < local_rows; ++row) {
    for (int col = 0; col < local_cols; ++col) {
      ASSERT_EQ(mat.GetLocal(row, col), expected);
    }
  }
}
Example #2
0
static bool writeDist(int fd, const char* filename, const DistMat& M, uint64_t* bytes)
{
    struct layer_header header;
    header.rank        = (uint64_t) M.Grid().Rank();
    header.width       = (uint64_t) M.Width();
    header.height      = (uint64_t) M.Height();
    header.localwidth  = (uint64_t) M.LocalWidth();
    header.localheight = (uint64_t) M.LocalHeight();
    header.ldim        = (uint64_t) M.LDim();

    ssize_t write_rc = write(fd, &header, sizeof(header));
    if (write_rc != sizeof(header)) {
        // error!
    }
    *bytes += write_rc;

    const Int localHeight = M.LocalHeight();
    const Int localWidth = M.LocalWidth();
    const Int lDim = M.LDim();
    if(localHeight == lDim) {
        void* buf = (void*) M.LockedBuffer();
        size_t bufsize = localHeight * localWidth * sizeof(DataType);
        write_rc = write(fd, buf, bufsize);
        if (write_rc != bufsize) {
            // error!
        }
        *bytes += write_rc;
    } else {
        for(Int j = 0; j < localWidth; ++j) {
            void* buf = (void*) M.LockedBuffer(0, j);
            size_t bufsize = localHeight * sizeof(DataType);
            write_rc = write(fd, buf, bufsize);
            if (write_rc != bufsize) {
                // error!
            }
            *bytes += write_rc;
        }
    }

    return true;
}
Example #3
0
/**
 * Read up to `count` bytes from `fd` into `buf`, retrying after short reads.
 *
 * @return Number of bytes actually read. A result smaller than `count` means
 *         read() reported an error or reached end-of-file.
 */
static size_t readDistBytes(int fd, void* buf, size_t count)
{
    char* pos = static_cast<char*>(buf);
    size_t done = 0;
    while (done < count) {
        const ssize_t rc = read(fd, pos + done, count - done);
        if (rc <= 0) {
            // Error (-1) or EOF (0): the caller sees a short count.
            break;
        }
        done += static_cast<size_t>(rc);
    }
    return done;
}

/**
 * Read a distributed matrix in the format written by writeDist: a
 * `layer_header` followed by the local entries column by column.
 *
 * Resizes `M` to the global dimensions in the header. It then requires the
 * header's local dimensions to match this process's local block, and fills
 * the local buffer column by column.
 *
 * @param fd       File descriptor open for reading.
 * @param filename Name of the file behind `fd`; not used by this function.
 * @param M        Matrix to resize and fill.
 * @param bytes    Incremented by the number of bytes actually read
 *                 (never by a negative error code).
 * @return true if the header and all local entries were read in full and the
 *         header's local dimensions match M's, false otherwise.
 */
static bool readDist(int fd, const char* filename, DistMat& M, uint64_t* bytes)
{
    (void) filename;

    struct layer_header header;
    const size_t headerRead = readDistBytes(fd, &header, sizeof(header));
    *bytes += headerRead;
    if (headerRead != sizeof(header)) {
        return false;
    }

    const Int height = static_cast<Int>(header.height);
    const Int width  = static_cast<Int>(header.width);
    M.Resize(height, width);

    // Always size the read from the LOCAL block. When both strides are 1, the
    // local block is normally the whole matrix, so this also covers that case.
    // It never reads more than the local buffer holds.
    const Int localHeight = M.LocalHeight();
    const Int localWidth  = M.LocalWidth();
    const Int lDim        = M.LDim();

    // The record holds header.localheight x header.localwidth entries. If
    // this process's block has a different shape, reading it would misplace
    // the data and desynchronize the rest of the stream, so refuse instead.
    if (header.localheight != static_cast<uint64_t>(localHeight) ||
        header.localwidth  != static_cast<uint64_t>(localWidth)) {
        return false;
    }

    // Compute sizes in size_t so the product cannot overflow in Int first.
    const size_t colBytes = static_cast<size_t>(localHeight) * sizeof(DataType);

    if (localHeight == lDim) {
        // Columns are stored back to back: one read fills all local data.
        const size_t bufsize = colBytes * static_cast<size_t>(localWidth);
        const size_t got = readDistBytes(fd, M.Buffer(), bufsize);
        *bytes += got;
        return got == bufsize;
    }

    // Columns are not contiguous (LDim() != LocalHeight()): read each one.
    for (Int jLoc = 0; jLoc < localWidth; ++jLoc) {
        const size_t got = readDistBytes(fd, M.Buffer(0, jLoc), colBytes);
        *bytes += got;
        if (got != colBytes) {
            return false;
        }
    }
    return true;
}