void selectScalarTest(const dim4 &dims) { if (noDoubleTests<T>()) return; dtype ty = (dtype)dtype_traits<T>::af_type; array a = randu(dims, ty); array cond = randu(dims, ty) > a; double b = 3; if (a.isinteger()) { a = (a % (1 << 30)).as(ty); } array c = is_right ? select(cond, a, b) : select(cond, b, a); int num = (int)a.elements(); vector<T> ha(num); vector<T> hc(num); vector<char> hcond(num); a.host(&ha[0]); c.host(&hc[0]); cond.host(&hcond[0]); if (is_right) { for (int i = 0; i < num; i++) { ASSERT_EQ(hc[i], hcond[i] ? ha[i] : T(b)); } } else { for (int i = 0; i < num; i++) { ASSERT_EQ(hc[i], hcond[i] ? T(b) : ha[i]); } } }
void selectTest(const dim4 &dims) { if (noDoubleTests<T>()) return; dtype ty = (dtype)dtype_traits<T>::af_type; array a = randu(dims, ty); array b = randu(dims, ty); if (a.isinteger()) { a = (a % (1 << 30)).as(ty); b = (b % (1 << 30)).as(ty); } array cond = randu(dims, ty) > a; array c = select(cond, a, b); int num = (int)a.elements(); vector<T> ha(num); vector<T> hb(num); vector<T> hc(num); vector<char> hcond(num); a.host(&ha[0]); b.host(&hb[0]); c.host(&hc[0]); cond.host(&hcond[0]); for (int i = 0; i < num; i++) { ASSERT_EQ(hc[i], hcond[i] ? ha[i] : hb[i]); } }
// Test named after issue 1249. It checks that select() over expressions built
// from the same array agrees with the equivalent arithmetic masking
// a - a * cond * 0.9 (a - a * 0.9 where cond holds, a unchanged elsewhere).
TEST(Select, ISSUE_1249) {
    const dim4 dims(2, 3, 4);
    // The mask is drawn before the data, keeping the original RNG call order.
    array cond = randu(dims) > 0.5;
    array a = randu(dims);
    array viaSelect = select(cond, a - a * 0.9, a);
    array viaArith = a - a * cond * 0.9;

    const int count = static_cast<int>(dims.elements());
    vector<float> hostSelect(count);
    vector<float> hostArith(count);
    viaSelect.host(hostSelect.data());
    viaArith.host(hostArith.data());

    for (int idx = 0; idx < count; ++idx) {
        EXPECT_NEAR(hostArith[idx], hostSelect[idx], 1e-7) << "at " << idx;
    }
}
// select(isNaN(in), fill, in) must replace exactly the NaN entries with `fill`
// and leave every other element unchanged.
TEST(Select, NaN) {
    const dim4 dims(1000, 1250);
    const dtype ty = f32;
    array in = randu(dims, ty);
    // Overwrite the rows picked by seq(in.dims(0) / 2) (the first half) with NaN.
    in(seq(in.dims(0) / 2), span, span, span) = NaN;
    const float fill = 0;
    array out = select(isNaN(in), fill, in);

    const int count = static_cast<int>(in.elements());
    vector<float> hostIn(count);
    vector<float> hostOut(count);
    in.host(hostIn.data());
    out.host(hostOut.data());

    for (int idx = 0; idx < count; ++idx) {
        const float expect = std::isnan(hostIn[idx]) ? fill : hostIn[idx];
        ASSERT_FLOAT_EQ(hostOut[idx], expect);
    }
}
vector<float> hc(num); b.host(&hb[0]); c.host(&hc[0]); for (int i = 0; i < num; i++) { EXPECT_NEAR(hc[i], hb[i], 1e-7) << "at " << i; } } TEST(Select, 4D) { dim4 dims(2, 3, 4, 2); array cond = randu(dims) > 0.5; array a = randu(dims); array b = select(cond, a - a * 0.9, a); array c = a - a * cond * 0.9; int num = (int)dims.elements(); vector<float> hb(num); vector<float> hc(num); b.host(&hb[0]); c.host(&hc[0]); for (int i = 0; i < num; i++) { EXPECT_NEAR(hc[i], hb[i], 1e-7) << "at " << i; } } TEST(Select, Issue_1730)