Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/xsimd/arch/xsimd_avx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1613,7 +1613,7 @@ namespace xsimd
}
return split;
}
constexpr auto lane_mask = mask % make_batch_constant<uint32_t, (mask.size / 2), A>();
constexpr auto lane_mask = mask % std::integral_constant<uint32_t, (mask.size / 2)>();
XSIMD_IF_CONSTEXPR(detail::is_only_from_lo(mask))
{
__m256 broadcast = _mm256_permute2f128_ps(self, self, 0x00); // [low | low]
Expand All @@ -1632,15 +1632,15 @@ namespace xsimd
__m256 swapped = _mm256_permute2f128_ps(self, self, 0x01); // [high | low]

// normalize mask taking modulo 4
constexpr auto half_mask = mask % make_batch_constant<uint32_t, 4, A>();
constexpr auto half_mask = mask % std::integral_constant<uint32_t, 4>();

// permute within each lane
__m256 r0 = _mm256_permutevar_ps(self, half_mask.as_batch());
__m256 r1 = _mm256_permutevar_ps(swapped, half_mask.as_batch());

// select lane by the mask index divided by 4
constexpr auto lane = batch_constant<uint32_t, A, 0, 0, 0, 0, 1, 1, 1, 1> {};
constexpr int lane_idx = ((mask / make_batch_constant<uint32_t, 4, A>()) != lane).mask();
constexpr int lane_idx = ((mask / std::integral_constant<uint32_t, 4>()) != lane).mask();

return _mm256_blend_ps(r0, r1, lane_idx);
}
Expand Down Expand Up @@ -1681,7 +1681,7 @@ namespace xsimd

// select lane by the mask index divided by 2
constexpr auto lane = batch_constant<uint64_t, A, 0, 0, 1, 1> {};
constexpr int lane_idx = ((mask / make_batch_constant<uint64_t, 2, A>()) != lane).mask();
constexpr int lane_idx = ((mask / std::integral_constant<uint64_t, 2>()) != lane).mask();

// blend the two permutes
return _mm256_blend_pd(r0, r1, lane_idx);
Expand Down
4 changes: 2 additions & 2 deletions include/xsimd/arch/xsimd_avx2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ namespace xsimd
return self;
}

constexpr auto lane_mask = mask % make_batch_constant<uint8_t, (mask.size / 2), A>();
constexpr auto lane_mask = mask % std::integral_constant<uint8_t, (mask.size / 2)>();

XSIMD_IF_CONSTEXPR(!detail::is_cross_lane(mask))
{
Expand Down Expand Up @@ -1409,7 +1409,7 @@ namespace xsimd
}
XSIMD_IF_CONSTEXPR(!detail::is_cross_lane(mask))
{
constexpr auto lane_mask = mask % make_batch_constant<uint32_t, (mask.size / 2), A>();
constexpr auto lane_mask = mask % std::integral_constant<uint32_t, (mask.size / 2)>();
// Cheaper intrinsics when not crossing lanes
// Contrary to the uint64_t version, the limits of 8 bits for the immediate constant
// cannot make different permutations across lanes
Expand Down
2 changes: 1 addition & 1 deletion include/xsimd/arch/xsimd_sse2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2074,7 +2074,7 @@ namespace xsimd
__m128i hi = _mm_unpackhi_epi64(hil, hih);

// mask to choose the right lane
constexpr auto blend_mask = mask < make_batch_constant<uint16_t, 4, A>();
constexpr auto blend_mask = mask < std::integral_constant<uint16_t, 4>();

// blend the two permutes
return select(blend_mask, batch<uint16_t, A>(lo), batch<uint16_t, A>(hi));
Expand Down
53 changes: 43 additions & 10 deletions include/xsimd/types/xsimd_batch_constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,16 @@ namespace xsimd
}

public:
#define MAKE_BINARY_OP(OP, NAME) \
template <T... OtherValues> \
constexpr auto operator OP(batch_constant<T, A, OtherValues...> other) const \
{ \
return apply<NAME<void>>(*this, other); \
#define MAKE_BINARY_OP(OP, NAME) \
template <T... OtherValues> \
constexpr auto operator OP(batch_constant<T, A, OtherValues...> other) const \
{ \
return apply<NAME<void>>(*this, other); \
} \
template <T OtherValue> \
constexpr batch_constant<T, A, (Values OP OtherValue)...> operator OP(std::integral_constant<T, OtherValue>) const \
{ \
return {}; \
}

MAKE_BINARY_OP(+, std::plus)
Expand Down Expand Up @@ -350,11 +355,16 @@ namespace xsimd
return apply_bool<F, std::tuple<std::integral_constant<T, Values>...>, std::tuple<std::integral_constant<T, OtherValues>...>>(std::make_index_sequence<sizeof...(Values)>());
}

#define MAKE_BINARY_BOOL_OP(OP, NAME) \
template <T... OtherValues> \
constexpr auto operator OP(batch_constant<T, A, OtherValues...> other) const \
{ \
return apply_bool<NAME<void>>(*this, other); \
#define MAKE_BINARY_BOOL_OP(OP, NAME) \
template <T... OtherValues> \
constexpr auto operator OP(batch_constant<T, A, OtherValues...> other) const \
{ \
return apply_bool<NAME<void>>(*this, other); \
} \
template <T OtherValue> \
constexpr batch_bool_constant<T, A, (Values OP OtherValue)...> operator OP(std::integral_constant<T, OtherValue>) const \
{ \
return {}; \
}

MAKE_BINARY_BOOL_OP(==, std::equal_to)
Expand Down Expand Up @@ -483,6 +493,29 @@ namespace xsimd

#endif

namespace generator
{
template <class T>
struct iota
{
static constexpr T get(size_t index, size_t)
{
return static_cast<T>(index);
}
};
}
/**
* @brief Build a @c batch_constant as an enumerated range
*
* @tparam T type of the data held in the batch.
* @tparam A Architecture that will be used when converting to a regular batch.
*/
template <typename T, class A = default_arch>
XSIMD_INLINE constexpr auto make_iota_batch_constant() noexcept
{
return make_batch_constant<T, generator::iota<T>, A>();
}

} // namespace xsimd

#endif
31 changes: 31 additions & 0 deletions test/test_batch_constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ struct constant_batch_test
constexpr auto b = xsimd::make_batch_constant<value_type, arange, arch_type>();
INFO("batch(value_type)");
CHECK_BATCH_EQ((batch_type)b, expected);

constexpr auto b_p = xsimd::make_iota_batch_constant<value_type, arch_type>();
INFO("batch(value_type)");
CHECK_BATCH_EQ((batch_type)b_p, expected);
}

template <value_type V>
Expand All @@ -106,43 +110,64 @@ struct constant_batch_test
{
constexpr auto n12 = xsimd::make_batch_constant<value_type, constant<12>, arch_type>();
constexpr auto n3 = xsimd::make_batch_constant<value_type, constant<3>, arch_type>();
constexpr std::integral_constant<value_type, 3> c3;

constexpr auto n12_add_n3 = n12 + n3;
constexpr auto n15 = xsimd::make_batch_constant<value_type, constant<15>, arch_type>();
static_assert(std::is_same<decltype(n12_add_n3), decltype(n15)>::value, "n12 + n3 == n15");
constexpr auto n12_add_c3 = n12 + c3;
static_assert(std::is_same<decltype(n12_add_c3), decltype(n15)>::value, "n12 + c3 == n15");

constexpr auto n12_sub_n3 = n12 - n3;
constexpr auto n9 = xsimd::make_batch_constant<value_type, constant<9>, arch_type>();
static_assert(std::is_same<decltype(n12_sub_n3), decltype(n9)>::value, "n12 - n3 == n9");
constexpr auto n12_sub_c3 = n12 - c3;
static_assert(std::is_same<decltype(n12_sub_c3), decltype(n9)>::value, "n12 - c3 == n9");

constexpr auto n12_mul_n3 = n12 * n3;
constexpr auto n36 = xsimd::make_batch_constant<value_type, constant<36>, arch_type>();
static_assert(std::is_same<decltype(n12_mul_n3), decltype(n36)>::value, "n12 * n3 == n36");
constexpr auto n12_mul_c3 = n12 * c3;
static_assert(std::is_same<decltype(n12_mul_c3), decltype(n36)>::value, "n12 - c3 == n36");

constexpr auto n12_div_n3 = n12 / n3;
constexpr auto n4 = xsimd::make_batch_constant<value_type, constant<4>, arch_type>();
static_assert(std::is_same<decltype(n12_div_n3), decltype(n4)>::value, "n12 / n3 == n4");
constexpr auto n12_div_c3 = n12 / c3;
static_assert(std::is_same<decltype(n12_div_c3), decltype(n4)>::value, "n12 / c3 == n4");

constexpr auto n12_mod_n3 = n12 % n3;
constexpr auto n0 = xsimd::make_batch_constant<value_type, constant<0>, arch_type>();
static_assert(std::is_same<decltype(n12_mod_n3), decltype(n0)>::value, "n12 % n3 == n0");
constexpr auto n12_mod_c3 = n12 % c3;
static_assert(std::is_same<decltype(n12_mod_c3), decltype(n0)>::value, "n12 % c3 == n0");

constexpr auto n12_land_n3 = n12 & n3;
static_assert(std::is_same<decltype(n12_land_n3), decltype(n0)>::value, "n12 & n3 == n0");
constexpr auto n12_land_c3 = n12 & c3;
static_assert(std::is_same<decltype(n12_land_c3), decltype(n0)>::value, "n12 & c3 == n0");

constexpr auto n12_lor_n3 = n12 | n3;
static_assert(std::is_same<decltype(n12_lor_n3), decltype(n15)>::value, "n12 | n3 == n15");
constexpr auto n12_lor_c3 = n12 | c3;
static_assert(std::is_same<decltype(n12_lor_c3), decltype(n15)>::value, "n12 | c3 == n15");

constexpr auto n12_lxor_n3 = n12 ^ n3;
static_assert(std::is_same<decltype(n12_lxor_n3), decltype(n15)>::value, "n12 ^ n3 == n15");
constexpr auto n12_lxor_c3 = n12 ^ c3;
static_assert(std::is_same<decltype(n12_lxor_c3), decltype(n15)>::value, "n12 ^ c3 == n15");

constexpr auto n96 = xsimd::make_batch_constant<value_type, constant<96>, arch_type>();
constexpr auto n12_lshift_n3 = n12 << n3;
static_assert(std::is_same<decltype(n12_lshift_n3), decltype(n96)>::value, "n12 << n3 == n96");
constexpr auto n12_lshift_c3 = n12 << c3;
static_assert(std::is_same<decltype(n12_lshift_c3), decltype(n96)>::value, "n12 << c3 == n96");

constexpr auto n1 = xsimd::make_batch_constant<value_type, constant<1>, arch_type>();
constexpr auto n12_rshift_n3 = n12 >> n3;
static_assert(std::is_same<decltype(n12_rshift_n3), decltype(n1)>::value, "n12 >> n3 == n1");
constexpr auto n12_rshift_c3 = n12 >> c3;
static_assert(std::is_same<decltype(n12_rshift_c3), decltype(n1)>::value, "n12 >> c3 == n1");

constexpr auto n12_uadd = +n12;
static_assert(std::is_same<decltype(n12_uadd), decltype(n12)>::value, "+n12 == n12");
Expand All @@ -163,21 +188,27 @@ struct constant_batch_test

static_assert(std::is_same<decltype(n12 == n12), true_batch_type>::value, "n12 == n12");
static_assert(std::is_same<decltype(n12 == n3), false_batch_type>::value, "n12 == n3");
static_assert(std::is_same<decltype(n12 == c3), false_batch_type>::value, "n12 == c3");

static_assert(std::is_same<decltype(n12 != n12), false_batch_type>::value, "n12 != n12");
static_assert(std::is_same<decltype(n12 != n3), true_batch_type>::value, "n12 != n3");
static_assert(std::is_same<decltype(n12 != c3), true_batch_type>::value, "n12 != c3");

static_assert(std::is_same<decltype(n12 < n12), false_batch_type>::value, "n12 < n12");
static_assert(std::is_same<decltype(n12 < n3), false_batch_type>::value, "n12 < n3");
static_assert(std::is_same<decltype(n12 < c3), false_batch_type>::value, "n12 < c3");

static_assert(std::is_same<decltype(n12 > n12), false_batch_type>::value, "n12 > n12");
static_assert(std::is_same<decltype(n12 > n3), true_batch_type>::value, "n12 > n3");
static_assert(std::is_same<decltype(n12 > c3), true_batch_type>::value, "n12 > c3");

static_assert(std::is_same<decltype(n12 <= n12), true_batch_type>::value, "n12 <= n12");
static_assert(std::is_same<decltype(n12 <= n3), false_batch_type>::value, "n12 <= n3");
static_assert(std::is_same<decltype(n12 <= c3), false_batch_type>::value, "n12 <= c3");

static_assert(std::is_same<decltype(n12 >= n12), true_batch_type>::value, "n12 >= n12");
static_assert(std::is_same<decltype(n12 >= n3), true_batch_type>::value, "n12 >= n3");
static_assert(std::is_same<decltype(n12 >= c3), true_batch_type>::value, "n12 >= c3");
}
};

Expand Down