AVX

SSE/SSE2 data types (128-bit registers):

  • __m128, 4 floats

  • __m128d, 2 doubles

  • __m128i, packed integers; depending on the instruction used, it is treated as 16 × 8-bit, 8 × 16-bit, 4 × 32-bit, or 2 × 64-bit lanes

AVX/AVX2 data types (256-bit registers; the types come from AVX, most integer operations on __m256i from AVX2):

  • __m256, 8 floats

  • __m256d, 4 doubles

  • __m256i, packed integers; 32 × 8-bit, 16 × 16-bit, 8 × 32-bit, or 4 × 64-bit lanes

./code/avx/main.cc
#include <cassert>
#include <cstdint>
#include <immintrin.h>
  3
  4// ps means packed signle precision
  5static void TestLoadStore() {
  6  alignas(16) float a[4] = {1, 2, 3, 4};
  7  alignas(16) float b[4];
  8  __m128 f = _mm_load_ps(a);
  9  // f = _mm_loadu_ps(a); // if a not aligned
 10  _mm_store_ps(b, f);
 11  // _mm_storeu_ps(b, f); // if b is not aligned
 12  assert(b[0] == a[0]);
 13  assert(b[1] == a[1]);
 14  assert(b[2] == a[2]);
 15  assert(b[3] == a[3]);
 16
 17  // set manually
 18  f = _mm_set_ps(a[3], a[2], a[1], a[0]);
 19  _mm_store_ps(b, f);
 20  assert(b[0] == a[0]);
 21  assert(b[1] == a[1]);
 22  assert(b[2] == a[2]);
 23  assert(b[3] == a[3]);
 24
 25  // for double
 26  alignas(32) double k[4] = {1, 2, 3, 4};
 27  __m256d d = _mm256_load_pd(k);
 28  // d = _mm256_loadu_pd(k); // if k is not aligned
 29  alignas(32) double m[4];
 30  _mm256_store_pd(m, d);
 31  // _mm256_storeu_pd(m, d); // if m is not aligned
 32  assert(m[0] == k[0]);
 33  assert(m[1] == k[1]);
 34  assert(m[2] == k[2]);
 35  assert(m[3] == k[3]);
 36
 37  d = _mm256_set_pd(k[3], k[2], k[1], k[0]);
 38  _mm256_store_pd(m, d);
 39  assert(m[0] == k[0]);
 40  assert(m[1] == k[1]);
 41  assert(m[2] == k[2]);
 42  assert(m[3] == k[3]);
 43}
 44
 45static void TestLoadStore1() {
 46  float a = 10;
 47  float b[4];
 48  __m128 f = _mm_load_ps1(&a);
 49  _mm_store_ps(b, f);
 50  assert(b[0] == a);
 51  assert(b[1] == a);
 52  assert(b[2] == a);
 53  assert(b[3] == a);
 54}
 55
 56static void TestAdd() {
 57  float a[4] = {1, 2, 3, 4};
 58  float b[4] = {10, 20, 30, 40};
 59  __m128 f = _mm_load_ps(a);
 60  __m128 g = _mm_load_ps(b);
 61  __m128 h = _mm_add_ps(f, g);
 62  float c[4];
 63  _mm_store_ps(c, h);
 64  assert(c[0] == a[0] + b[0]);
 65  assert(c[1] == a[1] + b[1]);
 66  assert(c[2] == a[2] + b[2]);
 67  assert(c[3] == a[3] + b[3]);
 68}
 69
 70static void AddIndex1(double *x, int32_t n) {
 71  for (int32_t i = 0; i < n; ++i) {
 72    x[i] = x[i] + i;
 73  }
 74}
 75
 76// assume n % 4 == 0
 77static void AddIndex2(double *x, int32_t n) {
 78  assert(n % 4 == 0);
 79  __m256d index, x_vec;
 80  for (int32_t i = 0; i < n; i += 4) {
 81    x_vec = _mm256_load_pd(x + i);
 82    // x_vec[0] = x[i]
 83    // x_vec[1] = x[i+1]
 84    // x_vec[2] = x[i+2]
 85    // x_vec[3] = x[i+3]
 86
 87    index = _mm256_set_pd(i + 3, i + 2, i + 1, i);
 88    // index[0] = i
 89    // index[1] = i+1
 90    // index[2] = i+2
 91    // index[3] = i+3
 92
 93    x_vec = _mm256_add_pd(x_vec, index);
 94    // x_vec[0] = x_vec[0] + index[0]
 95    // x_vec[1] = x_vec[1] + index[1]
 96    // x_vec[2] = x_vec[2] + index[2]
 97    // x_vec[3] = x_vec[3] + index[3]
 98
 99    _mm256_store_pd(x + i, x_vec);
100    // (x+i)[0] = x_vec[0]
101    // (x+i)[1] = x_vec[1]
102    // (x+i)[2] = x_vec[2]
103    // (x+i)[3] = x_vec[3]
104  }
105}
106
107static void TestAddIndex() {
108  alignas(32) double a[64];
109  alignas(32) double b[64];
110  for (int32_t i = 0; i != 64; ++i) {
111    a[i] = b[i] = i;
112  }
113  AddIndex1(a, 64);
114  AddIndex2(b, 64);
115  for (int32_t i = 0; i != 64; ++i) {
116    assert(a[i] == b[i]);
117  }
118}
119
120int main() {
121  TestLoadStore();
122  TestLoadStore1();
123  TestAdd();
124  TestAddIndex();
125  return 0;
126}