/**
 * @brief Loads this Blob's contents (and optionally its shape) from a BlobProto.
 *
 * @param proto   Serialized blob to read from.
 * @param reshape When true, this Blob is reshaped to the shape stored in
 *                `proto` (legacy num/channels/height/width fields take
 *                precedence over `proto.shape()` if any of them is set).
 *                When false, the current shape must already match `proto`
 *                (checked with ShapeEquals); a mismatch is a fatal CHECK
 *                failure.
 *
 * After the shape is settled, `count_` values are copied from `proto.data()`
 * into the buffer returned by mutable_cpu_data(). If `proto` carries any diff
 * values, `count_` values are likewise copied into mutable_cpu_diff().
 * The number of stored data (and, when present, diff) values must equal
 * `count_`; otherwise a fatal CHECK failure is raised instead of reading past
 * the end of the repeated field.
 */
void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) {
  if (reshape) {
    vector<int> shape;
    if (proto.has_num() || proto.has_channels() ||
        proto.has_height() || proto.has_width()) {
      // Using deprecated 4D Blob dimensions --
      // shape is (num, channels, height, width).
      shape.resize(4);
      shape[0] = proto.num();
      shape[1] = proto.channels();
      shape[2] = proto.height();
      shape[3] = proto.width();
    } else {
      const int num_axes = proto.shape().dim_size();
      shape.resize(num_axes);
      for (int i = 0; i < num_axes; ++i) {
        shape[i] = proto.shape().dim(i);
      }
    }
    Reshape(shape);
  } else {
    CHECK(ShapeEquals(proto)) << "shape mismatch (reshape not set)";
  }

  // Copy data. Guard against a proto whose payload length disagrees with the
  // blob's element count before indexing into the repeated field.
  CHECK_EQ(count_, proto.data_size())
      << "data size in proto does not match blob count";
  Dtype* data_vec = mutable_cpu_data();
  for (int i = 0; i < count_; ++i) {
    data_vec[i] = proto.data(i);
  }

  // Diff is optional; copy it only when the proto actually stores one.
  if (proto.diff_size() > 0) {
    CHECK_EQ(count_, proto.diff_size())
        << "diff size in proto does not match blob count";
    Dtype* diff_vec = mutable_cpu_diff();
    for (int i = 0; i < count_; ++i) {
      diff_vec[i] = proto.diff(i);
    }
  }
}
/**
 * @brief Tells whether this Blob's shape matches the shape stored in `other`.
 *
 * @param other Serialized blob whose shape is compared against `shape_`.
 * @return true when the shapes agree, false otherwise.
 *
 * If `other` sets any of the deprecated num/channels/height/width fields, the
 * comparison is done against those four values; otherwise it is done against
 * `other.shape()` axis by axis.
 */
bool Blob<Dtype>::ShapeEquals(const BlobProto& other) {
  const bool uses_legacy_dims = other.has_num() || other.has_channels() ||
                                other.has_height() || other.has_width();

  if (!uses_legacy_dims) {
    // Modern path: materialize the proto's shape and compare it wholesale.
    const int num_axes = other.shape().dim_size();
    vector<int> proto_shape(num_axes);
    for (int axis = 0; axis < num_axes; ++axis) {
      proto_shape[axis] = other.shape().dim(axis);
    }
    return shape_ == proto_shape;
  }

  // Deprecated 4D Blob dimensions: shape is (num, channels, height, width).
  // The regular Blob::num(), Blob::channels(), etc. accessors are deliberately
  // not used here: they index from the front of the shape, whereas legacy
  // parameter blobs were indexed from the back (e.g., a bias Blob of shape
  // (1 x 1 x 1 x N), or an IP layer weight Blob of shape (1 x 1 x M x N)).
  if (shape_.size() > 4) {
    return false;
  }
  return LegacyShape(-4) == other.num() &&
         LegacyShape(-3) == other.channels() &&
         LegacyShape(-2) == other.height() &&
         LegacyShape(-1) == other.width();
}
/**
 * @brief Restores this Blob from a BlobProto.
 *
 * @param proto        Serialized blob to read from.
 * @param need_reshape When true, the Blob is first reshaped to the dimensions
 *                     listed in `proto.shape()`.
 *
 * Data and diff are each copied only when `proto` contains at least one value
 * for them; in that case the number of stored values must equal count()
 * (enforced with CHECK_EQ).
 */
void Blob<Dtype>::FromProto(const BlobProto& proto, bool need_reshape = true) {
  // Adopt the proto's shape when requested.
  if (need_reshape) {
    vector<int> dims(proto.shape().dim_size());
    const int num_dims = static_cast<int>(dims.size());
    for (int axis = 0; axis < num_dims; ++axis) {
      dims[axis] = proto.shape().dim(axis);
    }
    reshape(dims);
  }

  // Copy the data values, if the proto carries any.
  if (proto.data_size() > 0) {
    CHECK_EQ(proto.data_size(), count());
    Dtype* dst = mutable_cpu_data();
    for (int idx = 0; idx < count_; ++idx) {
      dst[idx] = proto.data(idx);
    }
  }

  // Copy the diff values, if the proto carries any.
  if (proto.diff_size() > 0) {
    CHECK_EQ(proto.diff_size(), count());
    Dtype* dst = mutable_cpu_diff();
    for (int idx = 0; idx < count_; ++idx) {
      dst[idx] = proto.diff(idx);
    }
  }
}