// Checks that a PReLU layer built from a default LayerParameter produces the
// same forward output and the same bottom gradient as a ReLU layer configured
// with negative_slope = 0.25.
//
// NOTE(review): this presumably relies on PReLU's default slope initialisation
// being 0.25 — confirm against the PReLU parameter defaults if they change.
TYPED_TEST(NeuronLayerTest, TestPReLUConsistencyReLU) {
  typedef typename TypeParam::Dtype Dtype;
  LayerParameter prelu_layer_param;
  LayerParameter relu_layer_param;
  relu_layer_param.mutable_relu_param()->set_negative_slope(0.25);
  PReLULayer<Dtype> prelu(prelu_layer_param);
  ReLULayer<Dtype> relu(relu_layer_param);
  // Set up a second bottom/top pair owned by this test, so the ReLU layer
  // never shares blobs with the PReLU layer. The bottom is filled from the
  // fixture's bottom (CopyFrom(..., false, true): presumably data only, with
  // reshape — see Blob::CopyFrom).
  vector<Blob<Dtype>*> blob_bottom_vec_2;
  vector<Blob<Dtype>*> blob_top_vec_2;
  shared_ptr<Blob<Dtype> > blob_bottom_2(new Blob<Dtype>());
  shared_ptr<Blob<Dtype> > blob_top_2(new Blob<Dtype>());
  blob_bottom_vec_2.push_back(blob_bottom_2.get());
  blob_top_vec_2.push_back(blob_top_2.get());
  blob_bottom_2->CopyFrom(*this->blob_bottom_, false, true);
  // SetUp layers
  prelu.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
  relu.SetUp(blob_bottom_vec_2, blob_top_vec_2);
  // Check forward. ReLU runs on its own bottom copy: the same blobs it was
  // set up with and is back-propagated through below.
  prelu.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
  relu.Forward(blob_bottom_vec_2, blob_top_vec_2);
  for (int s = 0; s < blob_top_2->count(); ++s) {
    EXPECT_EQ(this->blob_top_->cpu_data()[s], blob_top_2->cpu_data()[s]);
  }
  // Check backward: feed an identical Gaussian-random top gradient into both
  // layers and compare the resulting bottom gradients element by element.
  shared_ptr<Blob<Dtype> > tmp_blob(new Blob<Dtype>());
  tmp_blob->ReshapeLike(*blob_top_2);
  FillerParameter filler_param;
  GaussianFiller<Dtype> filler(filler_param);
  filler.Fill(tmp_blob.get());
  caffe_copy(blob_top_2->count(), tmp_blob->cpu_data(),
      this->blob_top_->mutable_cpu_diff());
  caffe_copy(blob_top_2->count(), tmp_blob->cpu_data(),
      blob_top_2->mutable_cpu_diff());
  vector<bool> propagate_down;
  propagate_down.push_back(true);
  prelu.Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_);
  relu.Backward(blob_top_vec_2, propagate_down, blob_bottom_vec_2);
  for (int s = 0; s < blob_bottom_2->count(); ++s) {
    EXPECT_EQ(this->blob_bottom_->cpu_diff()[s], blob_bottom_2->cpu_diff()[s]);
  }
}
// Example #2
// 0
/*
 * Scalar Winograd F(6x6, 3x3) output transform of one 8x8 tile, fused with a
 * bias add and a ReLU, storing up to row_count x column_count results.
 *
 * transform        - transformed tile. Value (row r, column c) of the 8x8 tile
 *                    is read from transform + (8 * c + r) * (transform_stride / sizeof(float)).
 * output           - destination. Output row r (0 <= r < row_count) starts at
 *                    output + r * output_stride.
 * bias             - only bias[0] is read (see the note at the bias add).
 * transform_stride - distance between consecutive transform values, in BYTES.
 * output_stride    - distance between consecutive output rows, in floats.
 * row_count        - number of output rows to store. NOTE(review): expected to
 *                    be <= OUTPUT_SIZE; not checked here.
 * column_count     - number of output columns to store per row. At most six
 *                    values are stored per row.
 *
 * A zero row_count or column_count stores nothing.
 */
void nnp_owt8x8_3x3_with_bias_with_relu__scalar(
	const float* transform,
	float* output,
	const float* bias,
	size_t transform_stride, size_t output_stride,
	uint32_t row_count, uint32_t column_count)
{
	/*
	 * Nothing to store. Without this guard, a zero column_count would make
	 * --remaining_column_count below wrap around to UINT32_MAX. All six
	 * columns of every row would then be written past the caller's buffer.
	 */
	if (row_count == 0 || column_count == 0) {
		return;
	}

	/* The caller passes the transform stride in bytes; convert it to floats. */
	transform_stride /= sizeof(float);
	/* Always zero in this variant: the stored region starts at output[0]. */
	const uint32_t row_offset = 0;
	const uint32_t column_offset = 0;

	/* First pass: transform each of the BLOCK_SIZE tile columns. */
	float block[OUTPUT_SIZE][BLOCK_SIZE];
	for (uint32_t column = 0; column < BLOCK_SIZE; column++) {
		const float m0 = *transform;
		transform += transform_stride;
		float m1 = *transform;
		transform += transform_stride;
		const float m2 = *transform;
		transform += transform_stride;
		const float m3 = *transform;
		transform += transform_stride;
		const float m4 = *transform;
		transform += transform_stride;
		const float m5 = *transform;
		transform += transform_stride;
		const float m6 = *transform;
		transform += transform_stride;
		const float m7 = *transform;
		transform += transform_stride;

		/*
		 * The bias is added once, at tile position (row 1, column 1).
		 * Presumably winograd_f6k3_output_transform weights its m1 input
		 * by 1 in every output, so this single entry reaches all outputs
		 * of both passes. TODO confirm against the transform's definition.
		 */
		if (column == 1) {
			const float bias_value = *bias;
			m1 += bias_value;
		}

		winograd_f6k3_output_transform(
			m0, m1, m2, m3, m4, m5, m6, m7,
			&block[0][column], &block[1][column], &block[2][column],
			&block[3][column], &block[4][column], &block[5][column]);
	}

	/* Second pass: transform each requested row, apply ReLU, and store it. */
	const uint32_t row_end = row_offset + row_count;
	for (uint32_t row = row_offset; row < row_end; row++) {
		float s0, s1, s2, s3, s4, s5;
		winograd_f6k3_output_transform(
			block[row][0], block[row][1], block[row][2], block[row][3],
			block[row][4], block[row][5], block[row][6], block[row][7],
			&s0, &s1, &s2, &s3, &s4, &s5);
		float* row_output = output + (row - row_offset) * output_stride;
		uint32_t remaining_column_count = column_count;
		/*
		 * column_offset is 0, so execution always enters at case 0 and falls
		 * through, storing s0, s1, ... until column_count values (at most
		 * six) are written. relu(x, 0.0f): the second argument is presumably
		 * the negative slope, so 0.0f would give a plain ReLU. TODO confirm.
		 */
		switch (column_offset) {
			case 0:
				*row_output++ = relu(s0, 0.0f);
				if (--remaining_column_count == 0) {
					break;
				}
				/* fallthrough */
			case 1:
				*row_output++ = relu(s1, 0.0f);
				if (--remaining_column_count == 0) {
					break;
				}
				/* fallthrough */
			case 2:
				*row_output++ = relu(s2, 0.0f);
				if (--remaining_column_count == 0) {
					break;
				}
				/* fallthrough */
			case 3:
				*row_output++ = relu(s3, 0.0f);
				if (--remaining_column_count == 0) {
					break;
				}
				/* fallthrough */
			case 4:
				*row_output++ = relu(s4, 0.0f);
				if (--remaining_column_count == 0) {
					break;
				}
				/* fallthrough */
			case 5:
				*row_output = relu(s5, 0.0f);
				break;
			default:
				NNP_UNREACHABLE;
		}
	}
}