Skip to content

Microkernels

Chenhan D. Yu edited this page Nov 10, 2016 · 20 revisions

Microkernels are the only architecture-dependent implementations in HMLP, and they have a huge impact on performance. A microkernel in HMLP usually computes an MR-by-NR semiring rank-KC update (a.k.a. SEMIRINGKERNEL). Alternatively, in addition to the rank-KC update, it fuses an element-wise transformation and possibly a row-wise reduction (a.k.a. FUSEDKERNEL). These microkernels can be implemented in high-level languages, counting on compilers to optimize them for you. Alternatively, you can implement them in low-level assembly if you do not trust your compiler to do a good job.

SEMIRINGKERNEL

This is "the" most important kernel: it is the one that decides the performance. A reference kernel can be found in /hmlp/kernel/reference/semiring_mrxnr.hpp. The prototype looks like this:

/**
 *  Reference SEMIRINGKERNEL: an MR-by-NR semiring rank-k update written in
 *  plain C++ and annotated with unroll/SIMD pragmas for the compiler.
 *
 *  With op1 = add and op2 = mul the kernel computes v = a * b when aux->pc
 *  is zero (starting from initV) and v += a * b otherwise.
 *
 *  @tparam MR   number of rows of the output tile
 *  @tparam NR   number of columns of the output tile
 *  @tparam OP1  functor folding a new term into the running value
 *  @tparam OP2  functor combining one element of a with one element of b
 *  @tparam TA   element type of a
 *  @tparam TB   element type of b
 *  @tparam TC   only used to select the aux_s instantiation
 *  @tparam TV   element type of v and of the local tile
 */
template<
int MR, int NR,
typename OP1, typename OP2,
typename TA, typename TB, typename TC, typename TV>
struct semiring_mrxnr
{
  OP1 op1;
  OP2 op2;
  TV initV;

  /**
   *  @param k    number of rank-1 updates to apply
   *  @param a    read as a[ p * MR + i ] for p in [0, k), i in [0, MR)
   *  @param b    read as b[ p * NR + j ] for p in [0, k), j in [0, NR)
   *  @param v    tile addressed as v[ j * ldv + i ]; always written, and
   *              also read first when aux->pc is nonzero
   *  @param ldv  stride between consecutive columns of v
   *  @param aux  only aux->pc is consulted here
   */
  inline void operator()
  (
    int k,
    TA *a,
    TB *b,
    TV *v, int ldv,
    aux_s<TA, TB, TC, TV> *aux
  ) const
  {
    // Local copy of the tile; column `col` starts at offset col * MR.
    TV tile[ MR * NR ];

    if ( aux->pc )
    {
      // Continue from the partial result already stored in v.
      #pragma unroll
      for ( int col = 0; col < NR; ++col )
      {
        #pragma simd
        for ( int row = 0; row < MR; ++row )
        {
          tile[ col * MR + row ] = v[ col * ldv + row ];
        }
      }
    }
    else
    {
      // First update: every entry starts from initV.
      #pragma unroll
      for ( int col = 0; col < NR; ++col )
      {
        #pragma simd
        for ( int row = 0; row < MR; ++row )
        {
          tile[ col * MR + row ] = initV;
        }
      }
    }

    // Semiring rank-k update: one rank-1 update per step.
    for ( int step = 0; step < k; ++step )
    {
      TA *a_step = a + step * MR;
      TB *b_step = b + step * NR;

      #pragma unroll
      for ( int col = 0; col < NR; ++col )
      {
        #pragma simd
        for ( int row = 0; row < MR; ++row )
        {
          tile[ col * MR + row ] =
            op1( tile[ col * MR + row ], op2( a_step[ row ], b_step[ col ] ) );
        }
      }
    }

    // Write the tile back to v.
    #pragma unroll
    for ( int col = 0; col < NR; ++col )
    {
      #pragma simd
      for ( int row = 0; row < MR; ++row )
      {
        v[ col * ldv + row ] = tile[ col * MR + row ];
      }
    }
  }
};

An easy way to understand the kernel above is to let op1=add and op2=mul (with initV=0). Then it computes v=ab if the indicator pc is zero; otherwise it computes v+=ab instead. To help compilers optimize this kernel, we tell the compiler to perform loop unrolling on the j dimension and to use SIMD (Single Instruction Multiple Data) instructions on the i dimension. We know these optimizations are legal, since the only read-after-write dependency is in the p dimension (the loop over k).

However, even though we give compilers many optimization hints, in our experience the performance you get is still far from the best that an expert can achieve. Domain experts and performance ninjas may want to optimize a specific combination of <MR, NR, OP1, OP2, TA, TB, TC, TV> by hand. We specify the general plug-and-play prototype of SEMIRINGKERNEL here.

// Plug-and-play SEMIRINGKERNEL prototype (a naming/interface template, not
// compilable as written). The parenthesized parts of the struct name are
// placeholders: e.g. rank_k_asm_d8x4 is assembly, double precision, MR=8, NR=4.
// NOTE(review): TA/TB/TC/TV are presumably replaced by the concrete types of
// the specialization -- confirm against an existing kernel.
struct semiring_(language)_(type)(mr)x(nr)
{
  // Same call signature as the reference semiring_mrxnr::operator().
  inline void operator()
  (
    int k,
    TA *a,
    TB *b,
    TV *v, int ldv,
    aux_s<TA, TB, TC, TV> *aux
  ) const
  {
    // aux->pc is zero: first update, start from the initial value.
    if ( !aux->pc ) {}
    // aux->pc is nonzero: accumulate on top of the values already in v.
    else            {}
  };
};

Developers can do whatever they want as long as the kernel computes the correct answer. For example, HMLP has a modified version of the BLIS GEMM microkernel at /hmlp/kernel/x86_64/sandybridge/rank_k_d8x4.hpp. One implementation is called rank_k_asm_d8x4. The name means that this microkernel is written in assembly, in double precision, with MR=8 and NR=4.

FUSEDKERNEL

This kernel is almost the same as SEMIRINGKERNEL, but it takes an additional output c of type TC. Depending on the primitive you are writing kernels for, an additional element-wise transformation and even a row-wise (or column-wise) reduction will be performed in FUSEDKERNEL.

/**
 *  Reference FUSEDKERNEL: an MR-by-NR semiring rank-k update followed by an
 *  element-wise transformation (opkernel) and a reduction across the NR
 *  columns (opreduce), producing one value of type TC per row.
 *
 *  @tparam MR        number of rows of the tile (and of the output c)
 *  @tparam NR        number of columns of the tile
 *  @tparam OPKERNEL  functor applied to each entry of the updated tile
 *  @tparam OP1       functor folding a new term into the running value
 *  @tparam OP2       functor combining one element of a with one element of b
 *  @tparam OPREDUCE  functor folding a transformed entry into the row result
 *  @tparam TA        element type of a
 *  @tparam TB        element type of b
 *  @tparam TC        element type of the reduced output c
 *  @tparam TV        element type of v and of the local tile
 */
template<
int MR, int NR,
typename OPKERNEL, typename OP1, typename OP2, typename OPREDUCE,
typename TA, typename TB, typename TC, typename TV>
struct gkrm_mrxnr
{
  OPKERNEL opkernel;
  OP1 op1;
  OP2 op2;
  TV initV;
  // The reduction functor must be of type OPREDUCE and be named opreduce,
  // because that is the name operator() invokes below.
  OPREDUCE opreduce;
  TC initC;

  /**
   *  @param k    number of rank-1 updates to apply
   *  @param a    read as a[ p * MR + i ] for p in [0, k), i in [0, MR)
   *  @param b    read as b[ p * NR + j ] for p in [0, k), j in [0, NR)
   *  @param c    output; c[ 0 .. MR-1 ] is overwritten with the row results
   *  @param ldc  unused by this kernel
   *  @param v    partial tile, read as v[ j * ldv + i ] only when aux->pc is
   *              nonzero; never written here
   *  @param ldv  stride between consecutive columns of v
   *  @param aux  only aux->pc is consulted here
   */
  inline void operator()
  (
    int k,
    TA *a,
    TB *b,
    TC *c, int ldc, // ldc is redundant here
    TV *v, int ldv,
    aux_s<TA, TB, TC, TV> *aux
  ) const
  {
    TV regV[ MR * NR ];
    TC regC[ MR ];

    if ( !aux->pc ) // Initialize
    {
      #pragma unroll
      for ( int j = 0; j < NR; j ++ )
        #pragma simd
        for ( int i = 0; i < MR; i ++ )
          regV[ j * MR + i ] = initV;
    }
    else // accumulate on top of the earlier SEMIRINGKERNEL results in v
    {
      #pragma unroll
      for ( int j = 0; j < NR; j ++ )
        #pragma simd
        for ( int i = 0; i < MR; i ++ )
          regV[ j * MR + i ] = v[ j * ldv + i ];
    }

    // semiring rank-k update
    for ( int p = 0; p < k; p ++ )
    {
      #pragma unroll
      for ( int j = 0; j < NR; j ++ )
        #pragma simd
        for ( int i = 0; i < MR; i ++ )
          regV[ j * MR + i ] =
            op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
    }

    // Initialize the per-row reduction results
    #pragma simd
    for ( int i = 0; i < MR; i ++ )
      regC[ i ] = initC;

    // kernel transformation and reduction over the NR columns of each row
    #pragma unroll
    for ( int j = 0; j < NR; j ++ )
      #pragma simd
      for ( int i = 0; i < MR; i ++ )
        regC[ i ] = opreduce( regC[ i ], opkernel( regV[ j * MR + i ] ) );

    // store the reduced results
    #pragma simd
    for ( int i = 0; i < MR; i ++ )
      c[ i ] = regC[ i ];

  }
};

The only thing to be careful about is accumulating the result from v when pc is not zero. In this case, you have already computed several SEMIRINGKERNELs before you call the FUSEDKERNEL. You need to accumulate the previous results to compute the correct rank-KC update. We also specify the general plug-and-play prototype of FUSEDKERNEL here.

// Plug-and-play FUSEDKERNEL prototype (a naming/interface template, not
// compilable as written). The parenthesized parts of the struct name are
// placeholders for the primitive, implementation language, data type and
// MR x NR tile size.
// NOTE(review): TA/TB/TC/TV are presumably replaced by the concrete types of
// the specialization -- confirm against an existing kernel.
struct (primitive)_(language)_(type)(mr)x(nr)
{
  // Same call signature as the reference gkrm_mrxnr::operator().
  inline void operator()
  (
    int k,
    TA *a,
    TB *b,
    TC *c, int ldc,
    TV *v, int ldv,
    aux_s<TA, TB, TC, TV> *aux
  ) const
  {
    // aux->pc is zero: first update, start from the initial value.
    if ( !aux->pc ) {}
    // aux->pc is nonzero: accumulate the previous results stored in v.
    else            {}
  };
};

GSKS microkernel

GSKNN microkernel

Guideline for migrating BLIS kernel to HMLP kernel

Here I will explain how you can use existing BLIS microkernels in HMLP by a wrapper and some modifications to the interface. We take the sandybridge sgemm kernel bli_sgemm_asm_8x8 in bli_gemm_asm_d8x4.cpp as an example. The interface looks like:

// Original BLIS interface of the sandybridge sgemm microkernel, shown before
// any HMLP modification. The kernel body is elided; only the (commented-out)
// lines that touch the BLIS-internal auxinfo_t argument are listed.
void bli_sgemm_asm_8x8
     (
       dim_t               k,
       float*     restrict alpha,
       float*     restrict a,
       float*     restrict b,
       float*     restrict beta,
       float*     restrict c, inc_t rs_c, inc_t cs_c,
       auxinfo_t* restrict data,
       cntx_t*    restrict cntx
     )
{
	//void*   a_next = bli_auxinfo_next_a( data );
	//void*   b_next = bli_auxinfo_next_b( data );
};

We want to be free from all internal types used in BLIS. Thus, we replace auxinfo_t and cntx_t with our aux_s here. After replacing the data structure, the prototype should look like:

#include <hmlp_internal.hpp>

// BLIS sgemm microkernel after adapting its interface to HMLP: the two
// BLIS-internal trailing parameters (auxinfo_t, cntx_t) are replaced by a
// single HMLP aux_s pointer. The kernel body is elided; only the lines that
// changed are shown.
void bli_sgemm_asm_8x8
(
  dim_t               k,
  float*     restrict alpha,
  float*     restrict a,
  float*     restrict b,
  float*     restrict beta,
  // The two BLIS parameters below are removed from the interface:
  //auxinfo_t* restrict data,
  //cntx_t*    restrict cntx
  float*     restrict c, inc_t rs_c, inc_t cs_c,
  aux_s<float, float, float, float> *aux
)
{
  // Old BLIS accessors, no longer usable without auxinfo_t:
  //void*   a_next = bli_auxinfo_next_a( data );
  //void*   b_next = bli_auxinfo_next_b( data );
  // The next-panel pointers are now read from the HMLP aux structure.
  void*   a_next = (void*)aux->a_next;
  void*   b_next = (void*)aux->b_next;   
};

As long as bli_gemm_asm_d8x4.cpp is placed in kernel/x86_64/sandybridge, it will be compiled. Now you can write your operator as a wrapper to call it. For example, the wrapper should look like:

// Forward declaration of the adapted BLIS sgemm microkernel, so that the
// wrapper struct can call it. The parameter list has to match the adapted
// definition (aux_s in place of auxinfo_t/cntx_t).
void bli_sgemm_asm_8x8
(
  dim_t               k,
  float*     restrict alpha,
  float*     restrict a,
  float*     restrict b,
  float*     restrict beta,
  float*     restrict c, inc_t rs_c, inc_t cs_c,
  aux_s<float, float, float, float> *aux
);

struct rank_k_asm_s8x8
{
  // Strassen interface
  inline void operator()
  (
    int k,
    float *a,
    float *b,
    int len,
    float **c, int ldc, float *alpha,
    aux_s<float, float, float, float> *aux
  ) const
  {
    // Strassen implementation here.
  }; // end inline void operator()

  inline void operator()
  (
    int k,
    float *a,
    float *b,
    float *c, int ldc,
    aux_s<float, float, float, float> *aux
  ) const
  {
    float alpha = 1.0;
    float beta  = aux->pc ? 1.0 : 0.0;
    bli_sgemm_asm_8x4
    (
      k,
      &alpha,
      a,
      b,
      &beta,
      c, 1, ldc,
      aux
    );
  }; // end inline void operator()
};

Clone this wiki locally