void testBufferArgs(const BufferArgs& inputs,
                    const std::vector<CheckBufferArg>& check) {
  EXPECT_EQ(inputs.size(), check.size());
  for (size_t i = 0; i < inputs.size(); i++) {
    check[i](inputs[i]);
  }
}
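A minimal usage sketch of this helper, assuming CheckBufferArg is a callable over `const BufferArg&` (e.g. a `std::function<void(const BufferArg&)>` typedef, as in FunctionTest); the lambdas and the expected ranks below are illustrative only, not from the original source:

// Hedged sketch: one check per input argument, run in order.
std::vector<CheckBufferArg> checks = {
    // assumed example: first input is a 2-D matrix argument
    [](const BufferArg& arg) { EXPECT_EQ(arg.shape().ndims(), 2UL); },
    // assumed example: second input is a 1-D vector argument
    [](const BufferArg& arg) { EXPECT_EQ(arg.shape().ndims(), 1UL); }};
testBufferArgs(inputs, checks);  // runs checks[i] against inputs[i]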
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
  CHECK_EQ(1UL, inputs.size());
  CHECK_EQ(1UL, outputs.size());
  CHECK(inputs[0].isSequenceArg()) << "SequenceArg required here";
  const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
  CHECK(in_seq.data() && in_seq.getSequenceId().data() && outputs[0].data());
  CHECK_EQ(outputs[0].shape().ndims(), 2UL);
  CHECK_EQ(in_seq.shape().ndims(), 2UL);
  CHECK_EQ(in_seq.getSequenceId().shape().ndims(), 1UL);
  CHECK_EQ(in_seq.shape()[0], outputs[0].shape()[0]);
  /// output layer grad dim == weight dim * context_length_
  CHECK_EQ(in_seq.shape()[1], outputs[0].shape()[1] * context_length_);
  CHECK_EQ(outputs[0].getArgType(), ADD_TO);

  const auto seq_vec = in_seq.getSequenceId().vector<int, Device>();
  const auto out_grad_mat = in_seq.matrix<Device>();
  auto w_grad_mat = outputs[0].matrix<Device>();

  ContextProjectionBackwardWeight<Device>(out_grad_mat,
                                          w_grad_mat,
                                          seq_vec,
                                          context_length_,
                                          context_start_,
                                          total_pad_,
                                          begin_pad_);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
  CHECK_EQ(1UL, inputs.size());
  CHECK_EQ(1UL, outputs.size());
  CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
      << "SequenceArg required here";
  const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
  const auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);

  CHECK(in_seq.data() && out_seq.data() && in_seq.getSequenceId().data());
  CHECK_EQ(out_seq.shape().ndims(), 2UL);
  CHECK_EQ(in_seq.shape().ndims(), 2UL);
  CHECK_EQ(in_seq.getSequenceId().shape().ndims(), 1UL);
  /// output layer grad dim == input layer grad dim * context_length_
  CHECK_EQ(in_seq.shape()[1], out_seq.shape()[1] * context_length_);
  /// input and output have the same batch_size
  CHECK_EQ(in_seq.shape()[0], out_seq.shape()[0]);
  CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);

  const auto out_grad_mat = in_seq.matrix<Device>();
  const auto seq_vec = in_seq.getSequenceId().vector<int, Device>();
  auto in_grad_mat = out_seq.matrix<Device>();

  ContextProjectionBackwardData<Device>(
      out_grad_mat, in_grad_mat, seq_vec, context_length_, context_start_);
}
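A concrete instance of the dimension checks above, with hypothetical sizes (not from the source), to show why the output-layer gradient is `context_length_` times wider than the input-layer gradient:

// Illustrative shapes only (assumed values):
//   batch_size = 10, input-layer dim = 50, context_length_ = 3
//   in_seq  (output layer grad): shape [10, 150]   // 50 * 3
//   out_seq (input layer grad) : shape [10, 50]
// so in_seq.shape()[1] == out_seq.shape()[1] * context_length_ and
// in_seq.shape()[0] == out_seq.shape()[0] both hold.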
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
  CHECK_EQ(numInputs_, inputs.size());
  CHECK_EQ(numOutputs_, outputs.size());

  CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
  CHECK(inputs[0].shape() == outputs[0].shape());
  CHECK(inputs[0].shape() == outputs[1].shape());
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
  CHECK_EQ(1UL, inputs.size());
  CHECK(1UL == outputs.size() || 2UL == outputs.size());
  CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
      << "SequenceArg required here";
  const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
  const auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
  CHECK(in_seq.data() && in_seq.getSequenceId().data());
  CHECK_EQ(in_seq.shape().ndims(), 2UL);
  CHECK_EQ(out_seq.shape().ndims(), 2UL);
  CHECK_EQ(out_seq.getSequenceId().shape().ndims(), 1UL);

  /// input and output grad have the same batch_size
  CHECK_EQ(out_seq.shape()[0], in_seq.shape()[0]);
  /// dim of output grad == dim of input grad * context_length_
  CHECK_EQ(in_seq.shape()[1], out_seq.shape()[1] * context_length_);
  CHECK_EQ(out_seq.getArgType(), ADD_TO);

  if (2UL == outputs.size()) {
    CHECK_EQ(outputs[1].shape().ndims(), 2UL);
    /// dim of input grad == dim of weight
    CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]);
    CHECK_EQ(outputs[1].getArgType(), ADD_TO);
  }

  const auto seq_vec = in_seq.getSequenceId().vector<int, Device>();
  const auto out_grad_mat = in_seq.matrix<Device>();
  auto in_grad_mat =
      !out_seq.data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
                      : out_seq.matrix<Device>();
  auto w_grad_mat = (2UL == outputs.size() && outputs[1].data())
                        ? outputs[1].matrix<Device>()
                        : typename Tensor<real, Device>::Matrix(nullptr, 0, 0);

  ContextProjectionBackward<Device>(out_grad_mat,
                                    in_grad_mat,
                                    w_grad_mat,
                                    seq_vec,
                                    context_length_,
                                    context_start_,
                                    begin_pad_,
                                    is_padding_,
                                    total_pad_);
}
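The ternaries above use an empty matrix as a sentinel for an absent gradient. A minimal sketch of that pattern (illustrative only; `real` and `Device` come from the surrounding template context):

// Assumed illustration: a Matrix wrapping a null buffer with zero extent
// stands in for "no gradient requested", so ContextProjectionBackward can
// skip the corresponding accumulation branch.
auto none = typename Tensor<real, Device>::Matrix(nullptr, 0, 0);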
// Only the shape of one input is needed to calculate the
// number of floating-point operations.
size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
  CHECK_LT((size_t)1, inputs.size());
  size_t batchSize = inputs[0].shape()[0];
  size_t maps = inputs[0].shape()[1];
  size_t rows = inputs[0].shape()[2];
  size_t columns = inputs[0].shape()[3];

  // Approximate number of floating-point operations.
  size_t ops = batchSize * maps * rows * columns * (size_ * 4 + 2);

  return ops;
}
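A worked example of the estimate, with hypothetical sizes chosen purely to illustrate the formula:

// Assumed sizes: batchSize = 128, maps = 32, rows = 32, columns = 32,
// size_ = 5.
//   ops = 128 * 32 * 32 * 32 * (5 * 4 + 2)
//       = 4194304 * 22
//       = 92274688   (~9.2e7 floating-point operations)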
void testBufferArgs(const BufferArgs& inputs, const CheckBufferArg& check) {
  EXPECT_EQ(inputs.size(), 1U);
  check(inputs[0]);
}