Custom matrix multiplication kernels in C: A friendly introduction to FMA’s complement

Kibicho Murage
Apr 2, 2024

Murage Kibicho — AI Researcher at FileForma


Optimized matrix multiplication kernels power modern artificial intelligence algorithms. This article introduces Fused Multiply-Add’s complement, or FMA’s complement, a custom binary number system designed to calculate vector dot products. Technically, FMA’s complement is a fixed-point representation of binary numbers, created as an alternative to 2’s complement and 1’s complement, for performing the fused multiply-add operation.

FMA’s complement is still in R&D. The algorithm is currently somewhat slow; however, this article guides programmers through a working implementation in the C language. Programmers who follow along will see four important results.

  1. Dot product alternative: FMA’s complement opens up research into alternative algorithms for calculating vector dot products.
  2. Arithmetic on compressed forms: FMA’s complement compresses matrices and permits matrix multiplication in a compressed state, without decoding first.
  3. Cache awareness: FMA’s complement caches intermediate dot product results, which reduces the total number of CPU instructions needed to perform a matrix multiplication.
  4. Linear basis multiplications: FMA’s complement lets one calculate a few vector dot products and guess (with reasonable accuracy) the results of related vector dot products.

Fileforma is a startup dedicated to binary format research. Our current focus is on custom number systems for matrix multiplications. This tutorial is intended for programmers interested in writing their own FMA’s complement library. We are pre-selling commercial research licenses here. This guarantees you access to our research on matrix multiplication kernels.

Send feedback to murage@fileforma.com or DM me on X/Twitter @murage_kibicho. We raised a 3k round. I’m super grateful to our angels. Please reach out if you’re interested in our pre-seed round.

A Primer on Matrix Multiplications, Dot Products and Fused Multiply-Add

A matrix multiplication is the product of two matrices. There are at least five different ways to look at matrix multiplication. We focus on one interpretation: a matrix multiplication is a collection of individual dot products.

The dot product is an algebraic operation that takes two equal-length sequences of numbers and returns a single number (Ref.). The image below visualizes the result of multiplying two 2×2 matrices. Here, C = A * B, and the four dots (⋅) on the right represent four individual dot products.

Individual dot products constitute a matrix multiplication
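A minimal sketch of this view in plain C, using double rather than the fixed-point type introduced later. DotProduct and MatmulAsDotProducts are hypothetical helper names, not part of the article's code. Each entry C[i][j] is the dot product of row i of A with column j of B, so a 2×2 product needs exactly four dot products.

```c
#include <assert.h>
#include <stddef.h>

/* Dot product of two equal-length sequences: the sum of pairwise products. */
double DotProduct(const double *x, const double *y, size_t n)
{
    assert(n == 0 || (x != NULL && y != NULL));
    double sum = 0.0;
    for (size_t k = 0; k < n; k++) {
        sum += x[k] * y[k];
    }
    return sum;
}

/* C = A * B for 2x2 matrices, computed as four separate dot products:
   C[i][j] = (row i of A) . (column j of B). */
void MatmulAsDotProducts(double A[2][2], double B[2][2], double C[2][2])
{
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            double column[2] = {B[0][j], B[1][j]};  /* gather column j of B */
            C[i][j] = DotProduct(A[i], column, 2);
        }
    }
}
```

With A = [[1, 2], [3, 4]] and B = [[5, 6], [7, 8]], the four dot products give C = [[19, 22], [43, 50]].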

Implementing a dot product involves both multiplication and addition, which are normally two different CPU instructions. Hardware vendors provide a fused multiply-add (FMA) instruction, also exposed as SIMD intrinsics, that computes a × b + c in a single instruction.

Fused multiply-add instruction
ARM developer Fused multiply-add intrinsic

Building FMA’s Complement from Scratch

Three components constitute our working implementation of FMA’s complement.

  1. Fixed-point arithmetic for calculations.
  2. A binary arithmetic coder for compression.
  3. A lookup table.

1. Performing Fixed-Point Arithmetic

You need the Fixed Point C Math library (fixedptc) from SourceForge to proceed. First, create a new folder on your computer. Download the file fixedptc.h from SourceForge and place it in the folder. Then create main.c in the same folder. Your working directory should resemble the image below.

Sample directory structure

We implement our matrices as 2-dimensional arrays (an array of row pointers) using the fixedpt data type defined in fixedptc.h. Our current FMA’s complement implementation uses 32-bit fixed-point integers.
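For a sense of what a 32-bit fixedpt holds: fixedptc splits the bits into a whole part and a fractional part, and its default for 32-bit values (FIXEDPT_WBITS of 24) leaves 8 fractional bits, so each value is an integer count of 1/256ths. The matrices printed later in this article are all multiples of 1/256, which is consistent with that default. The standalone sketch below imitates the 24.8 format without the library; ToFixed8, FromFixed8 and MulFixed8 are hypothetical names, and the rounding and truncation choices are our reading of fixedpt_rconst and fixedpt_mul, not a copy of them.

```c
#include <assert.h>
#include <stdint.h>

#define FRAC_BITS 8                 /* assumption: 24.8 format, like fixedptc's default */
#define FIXED_ONE (1 << FRAC_BITS)  /* 1.0 is stored as 256 */

/* Round a real number to the nearest 24.8 value (ties away from zero). */
int32_t ToFixed8(double r)
{
    return (int32_t)(r * FIXED_ONE + (r >= 0 ? 0.5 : -0.5));
}

/* Convert a raw 24.8 value back to a real number. */
double FromFixed8(int32_t f)
{
    return (double)f / FIXED_ONE;
}

/* Multiply in 64 bits, then drop the extra 8 fractional bits.
   Right-shifting a negative value is implementation-defined in C;
   GCC and Clang shift arithmetically, which rounds toward minus infinity. */
int32_t MulFixed8(int32_t a, int32_t b)
{
    return (int32_t)(((int64_t)a * b) >> FRAC_BITS);
}
```

For example, ToFixed8(-0.359375) gives -92, and MulFixed8(-30, 149) gives -18: the exact result -4470 / 256 ≈ -17.46 is rounded down, not to the nearest value.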

Inside main.c, we include all the libraries we need and configure the fixed-point library. FIXEDPT_BITS must be defined before fixedptc.h is included.

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <assert.h>
#include <math.h>
#include <time.h>
#define FIXEDPT_BITS 32
#include "fixedptc.h"
#include <sys/time.h>

Next, we write a function to allocate a 2-dimensional array and fill the array with random elements within a range.

fixedpt **CreateRandomMatrix(int rowCount, int columnCount, float lowerBound, float upperBound)
{
    /*One pointer per row, then one array per row*/
    fixedpt **result = malloc(rowCount * sizeof(fixedpt *));
    for (int i = 0; i < rowCount; i++)
    {
        result[i] = calloc(columnCount, sizeof(fixedpt));
    }
    /*Fill with uniform random floats in [lowerBound, upperBound], converted to fixed point*/
    float randomFloat = 0.0f;
    for (int i = 0; i < rowCount; i++)
    {
        for (int j = 0; j < columnCount; j++)
        {
            randomFloat = ((float)rand() / RAND_MAX) * (upperBound - lowerBound) + lowerBound;
            result[i][j] = fixedpt_rconst(randomFloat);
        }
    }
    return result;
}

We also write three helpers: one to allocate a matrix filled with zeros, one to free an allocated matrix, and one to print a matrix.

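fixedpt **CreateZeroMatrix(int rowCount, int columnCount)
{
    /*calloc zero-fills, so every entry starts at fixed-point 0*/
    fixedpt **result = malloc(rowCount * sizeof(fixedpt *));
    for (int i = 0; i < rowCount; i++)
    {
        result[i] = calloc(columnCount, sizeof(fixedpt));
    }
    return result;
}

void FreeMatrix(int rowCount, fixedpt **matrix)
{
    /*Free each row, then the array of row pointers*/
    for (int i = 0; i < rowCount; i++)
    {
        free(matrix[i]);
    }
    free(matrix);
}

/*Helper: print one fixed-point number as a decimal string (-2 requests all available digits)*/
void PrintFixedPoint(fixedpt A)
{
    char num[20];
    fixedpt_str(A, num, -2);
    printf("%s", num);
}

void PrintMatrix(int rowCount, int columnCount, fixedpt **matrix)
{
    for (int i = 0; i < rowCount; i++)
    {
        for (int j = 0; j < columnCount; j++)
        {
            PrintFixedPoint(matrix[i][j]);
            printf(",");
        }
        printf("\n");
    }
    printf("\n");
}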

Finally, we write a simple matrix multiplication algorithm. For each row i of mat0 and each column j of mat1, we accumulate their dot product into result[i][j] using fixedpt_mul and fixedpt_add.

void NaiveFixedMatmul(int row0, int column0, int row1, int column1, fixedpt **mat0, fixedpt **mat1, fixedpt **result)
{
    assert(column0 == row1); /*Inner dimensions must match*/
    fixedpt temp = 0;
    for (int i = 0; i < row0; i++)
    {
        for (int j = 0; j < column1; j++)
        {
            /*result[i][j] = dot product of row i of mat0 and column j of mat1*/
            result[i][j] = 0;
            for (int k = 0; k < column0; k++)
            {
                temp = fixedpt_mul(mat0[i][k], mat1[k][j]);
                result[i][j] = fixedpt_add(result[i][j], temp);
            }
        }
    }
}

We also include a main function and add code to measure our function’s running time. Your main.c file should resemble this.

//Compile: gcc main.c -lm -o main && ./main
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <assert.h>
#include <math.h>
#define FIXEDPT_BITS 32 /*Must be defined before including fixedptc.h*/
#include "fixedptc.h"
#include <sys/time.h>

/*Step 1: Create a random matrix*/
fixedpt **CreateRandomMatrix(int rowCount, int columnCount, float lowerBound, float upperBound)
{
    fixedpt **result = malloc(rowCount * sizeof(fixedpt *));
    for (int i = 0; i < rowCount; i++)
    {
        result[i] = calloc(columnCount, sizeof(fixedpt));
    }
    float randomFloat = 0.0f;
    for (int i = 0; i < rowCount; i++)
    {
        for (int j = 0; j < columnCount; j++)
        {
            randomFloat = ((float)rand() / RAND_MAX) * (upperBound - lowerBound) + lowerBound;
            result[i][j] = fixedpt_rconst(randomFloat);
        }
    }
    return result;
}

/*Step 2: Create a matrix filled with zeros*/
fixedpt **CreateZeroMatrix(int rowCount, int columnCount)
{
    fixedpt **result = malloc(rowCount * sizeof(fixedpt *));
    for (int i = 0; i < rowCount; i++)
    {
        result[i] = calloc(columnCount, sizeof(fixedpt));
    }
    return result;
}

/*Step 3: Free our matrix*/
void FreeMatrix(int rowCount, fixedpt **matrix)
{
    for (int i = 0; i < rowCount; i++)
    {
        free(matrix[i]);
    }
    free(matrix);
}

/*Step 4: Print our matrix as a float*/
//Helper function to print fixedpoint number as floating point
void PrintFixedPoint(fixedpt A)
{
    char num[20];
    fixedpt_str(A, num, -2);
    printf("%s", num);
}

void PrintMatrix(int rowCount, int columnCount, fixedpt **matrix)
{
    for (int i = 0; i < rowCount; i++)
    {
        for (int j = 0; j < columnCount; j++)
        {
            PrintFixedPoint(matrix[i][j]);
            printf(",");
        }
        printf("\n");
    }
    printf("\n");
}

/*Step 5: Simple matrix multiplication
matrix0 * matrix1 = result
*/
void NaiveFixedMatmul(int row0, int column0, int row1, int column1, fixedpt **mat0, fixedpt **mat1, fixedpt **result)
{
    assert(column0 == row1);
    fixedpt temp = 0;
    for (int i = 0; i < row0; i++)
    {
        for (int j = 0; j < column1; j++)
        {
            result[i][j] = 0;
            for (int k = 0; k < column0; k++)
            {
                temp = fixedpt_mul(mat0[i][k], mat1[k][j]);
                result[i][j] = fixedpt_add(result[i][j], temp);
            }
        }
    }
}

/*Step 6: Test our function*/
int main()
{
    /*Structs to hold current time*/
    struct timeval startTime;
    struct timeval endTime;
    double timeSpent = 0.0;

    /*Initialize our random number generator with a seed*/
    srand(97656);
    /*Set lower and upper bounds for the floats we generate*/
    float lowerBound = -1.3;
    float upperBound = 1.3;
    /*Set number of rows and number of columns for our matrices*/
    int row0 = 3;
    int column0 = 2;
    int row1 = column0;
    int column1 = 4;

    /*Generate our matrices*/
    fixedpt **matrix0 = CreateRandomMatrix(row0, column0, lowerBound, upperBound);
    fixedpt **matrix1 = CreateRandomMatrix(row1, column1, lowerBound, upperBound);
    fixedpt **result = CreateZeroMatrix(row0, column1);

    /*Get start time*/
    gettimeofday(&startTime, NULL);
    NaiveFixedMatmul(row0, column0, row1, column1, matrix0, matrix1, result);
    /*Get end time*/
    gettimeofday(&endTime, NULL);
    /*Print our matrices (uncomment to reproduce the output below)*/
    //PrintMatrix(row0, column0, matrix0);
    //PrintMatrix(row1, column1, matrix1);
    //PrintMatrix(row0, column1, result);

    // Calculate the elapsed time in milliseconds
    timeSpent = (endTime.tv_sec - startTime.tv_sec) * 1000.0;     // Seconds to milliseconds
    timeSpent += (endTime.tv_usec - startTime.tv_usec) / 1000.0;  // Microseconds to milliseconds

    printf("Time spent %.3f\n", timeSpent);

    /*Release the matrices*/
    FreeMatrix(row0, matrix0);
    FreeMatrix(row1, matrix1);
    FreeMatrix(row0, result);
    return 0;
}
/*
1.05859375,0.77734375,
-0.359375,-0.1171875,
1.27734375,-0.45703125,

-0.44921875,-0.33984375,-1.2578125,-0.19140625,
-0.73046875,0.75,1.23046875,0.58203125,

-1.046875,0.21875,-0.37890625,0.24609375,
0.2421875,0.03125,0.3046875,-0.00390625,
-0.2421875,-0.78125,-2.171875,-0.515625,

Time spent 0.000
*/
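With a 3×2 and a 2×4 matrix, the multiply is too fast to register at the three-decimal millisecond precision printed here, so the program reports 0.000. One option, sketched below under the assumption of a POSIX system, is to read clock_gettime with CLOCK_MONOTONIC (which is not affected by wall-clock adjustments) and average over many repetitions. ElapsedMs and TimeRepeatedMs are hypothetical helpers, not part of the article's code.

```c
#define _POSIX_C_SOURCE 199309L  /* expose clock_gettime in strict C modes */
#include <assert.h>
#include <time.h>

/* Milliseconds between two timestamps. */
double ElapsedMs(struct timespec start, struct timespec end)
{
    return (double)(end.tv_sec - start.tv_sec) * 1000.0
         + (double)(end.tv_nsec - start.tv_nsec) / 1.0e6;
}

/* Run work(context) `repetitions` times; return the mean time per call in ms. */
double TimeRepeatedMs(void (*work)(void *), void *context, int repetitions)
{
    assert(repetitions > 0);
    struct timespec start, end;
    clock_gettime(CLOCK_MONOTONIC, &start);
    for (int r = 0; r < repetitions; r++) {
        work(context);
    }
    clock_gettime(CLOCK_MONOTONIC, &end);
    return ElapsedMs(start, end) / repetitions;
}
```

To time NaiveFixedMatmul this way, wrap it in a small function that takes a pointer to a struct holding its arguments, then call TimeRepeatedMs with a large repetition count.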

2. Adding an Arithmetic Coder

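At the heart of FMA’s complement is a binary arithmetic coder. Arithmetic coding is a data compression technique that stores an entire message as a single number inside an interval: each symbol narrows the interval in proportion to its probability, so likely symbols cost fewer bits. Here’s the original paper.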

Arithmetic coding Original Paper Screenshot

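Our arithmetic coder keeps its current interval in 32-bit integers (RANGE_LOW to RANGE_HIGH) and quantizes each probability to 16 bits, so probabilities become multiples of 1/2¹⁶. First, we create a new file called ArithmeticCoder.h and define some global variables. We need these global variables to cache our intermediate results.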

uint32_t RANGE_LOW = 1;            /*Minimum value in our range*/
uint32_t RANGE_HIGH = 0xffffffff;  /*Maximum value in our range*/
uint32_t RANGE_CURRENT = 0;        /*Current value in our range*/
unsigned char *byteHolder;         /*Array to hold encoded bytes*/
uint32_t byteHolderIndex = 0;      /*Current index into byteHolder*/
uint32_t byteHolderLength = 0;     /*Maximum length of byteHolder*/

Next, we write a function to reset our global variables.

void ResetEncoder()
{
    RANGE_LOW = 1;
    RANGE_HIGH = 0xffffffff;
    RANGE_CURRENT = 0;
    byteHolderIndex = 0;
}

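Finally, we write the encoding function for our arithmetic coder. It takes a bit (1 or 0) and the probability that the bit is 1. Each call splits the current range at RANGE_MID, keeps the part that belongs to the bit, and writes any leading bytes that can no longer change into byteHolder.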

void Encode(int bit, float probability)
{
    /*probability is the chance that bit == 1*/
    assert(probability >= 0.0f);
    assert(probability <= 1.0f);
    assert(bit == 0 || bit == 1);
    assert(RANGE_HIGH > RANGE_LOW);
    /*Quantize the probability to 16 bits*/
    int intProbability = (int) (probability * 65536.0f);
    /*RANGE_MID = RANGE_LOW + floor((RANGE_HIGH - RANGE_LOW) * intProbability / 2^16), without 64-bit math*/
    uint32_t RANGE_MID = RANGE_LOW + ((RANGE_HIGH - RANGE_LOW) >> 16) * intProbability + ((((RANGE_HIGH - RANGE_LOW) & 0xffff) * intProbability) >> 16);
    assert(RANGE_HIGH >= RANGE_MID);
    assert(RANGE_MID >= RANGE_LOW);
    /*A 1 keeps [RANGE_LOW, RANGE_MID]; a 0 keeps [RANGE_MID + 1, RANGE_HIGH]*/
    if (bit) { RANGE_HIGH = RANGE_MID; } else { RANGE_LOW = RANGE_MID + 1; }
    /*Once the leading bytes agree they can never change: emit them and shift*/
    while ((RANGE_HIGH ^ RANGE_LOW) < 0x1000000)
    {
        assert(byteHolderIndex < byteHolderLength);
        if (byteHolderIndex < byteHolderLength) { byteHolder[byteHolderIndex] = RANGE_HIGH >> 24; }
        byteHolderIndex += 1;
        RANGE_HIGH = (RANGE_HIGH << 8) | 255;
        RANGE_LOW = RANGE_LOW << 8;
    }
}
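The RANGE_MID line is easier to read once you see what it computes. Write the width as range = RANGE_HIGH - RANGE_LOW = a·2¹⁶ + b with b < 2¹⁶. The two terms are a·p and floor(b·p / 2¹⁶), and because a·p is already an integer their sum is exactly floor(range·p / 2¹⁶). So the split point is RANGE_LOW + floor(range·p / 65536), computed without a 64-bit product. The sketch below checks this identity against a 64-bit reference; SplitPoint32 and SplitPoint64 are hypothetical names.

```c
#include <assert.h>
#include <stdint.h>

/* Split point exactly as written in Encode(): 32-bit arithmetic only.
   p16 is the probability scaled to 0..65536. */
uint32_t SplitPoint32(uint32_t low, uint32_t high, uint32_t p16)
{
    assert(high >= low && p16 <= 65536u);
    uint32_t range = high - low;
    return low + (range >> 16) * p16 + (((range & 0xffffu) * p16) >> 16);
}

/* Reference version: low + floor(range * p16 / 2^16) with a 64-bit product. */
uint32_t SplitPoint64(uint32_t low, uint32_t high, uint32_t p16)
{
    uint64_t product = (uint64_t)(high - low) * p16;
    return low + (uint32_t)(product >> 16);
}
```

For instance, with RANGE_LOW = 1, RANGE_HIGH = 0xffffffff and p = 32768 (a probability of one half), both functions return 0x80000000.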

The complete ArithmeticCoder.h file should resemble this.

#ifndef ARITHMETIC_CODER_H
#define ARITHMETIC_CODER_H

#include <stdlib.h>
#include <string.h>
#include <limits.h>
#include <assert.h>
#include <stdio.h>
#include <math.h>
#include <stdint.h>

uint32_t RANGE_LOW = 1;            /*Minimum value in our range*/
uint32_t RANGE_HIGH = 0xffffffff;  /*Maximum value in our range*/
uint32_t RANGE_CURRENT = 0;        /*Current value in our range*/
unsigned char *byteHolder;         /*Array to hold encoded bytes*/
uint32_t byteHolderIndex = 0;      /*Current index into byteHolder*/
uint32_t byteHolderLength = 0;     /*Maximum length of byteHolder*/

/*Reset global variables*/
void ResetEncoder()
{
    RANGE_LOW = 1;
    RANGE_HIGH = 0xffffffff;
    RANGE_CURRENT = 0;
    byteHolderIndex = 0;
}

/*Encode a single bit; probability is the chance that bit == 1*/
void Encode(int bit, float probability)
{
    assert(probability >= 0.0f);
    assert(probability <= 1.0f);
    assert(bit == 0 || bit == 1);
    assert(RANGE_HIGH > RANGE_LOW);
    /*RANGE_MID = RANGE_LOW + floor((RANGE_HIGH - RANGE_LOW) * intProbability / 2^16), without 64-bit math*/
    int intProbability = (int) (probability * 65536.0f);
    uint32_t RANGE_MID = RANGE_LOW + ((RANGE_HIGH - RANGE_LOW) >> 16) * intProbability + ((((RANGE_HIGH - RANGE_LOW) & 0xffff) * intProbability) >> 16);
    assert(RANGE_HIGH >= RANGE_MID);
    assert(RANGE_MID >= RANGE_LOW);
    /*A 1 keeps [RANGE_LOW, RANGE_MID]; a 0 keeps [RANGE_MID + 1, RANGE_HIGH]*/
    if (bit) { RANGE_HIGH = RANGE_MID; } else { RANGE_LOW = RANGE_MID + 1; }
    /*Once the leading bytes agree they can never change: emit them and shift*/
    while ((RANGE_HIGH ^ RANGE_LOW) < 0x1000000)
    {
        assert(byteHolderIndex < byteHolderLength);
        if (byteHolderIndex < byteHolderLength) { byteHolder[byteHolderIndex] = RANGE_HIGH >> 24; }
        byteHolderIndex += 1;
        RANGE_HIGH = (RANGE_HIGH << 8) | 255;
        RANGE_LOW = RANGE_LOW << 8;
    }
}

#endif /*ARITHMETIC_CODER_H*/
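The article shows only the encoder, so it can be hard to see that the bytes in byteHolder really pin down the input bits. One way to check is to run them back through a matching decoder. The sketch below is our assumption, not code from the article or the Fileforma repository: it repackages the same split and byte-emitting rules into a Coder struct instead of globals, adds a hypothetical Flush that writes the four bytes of the final low bound, and adds a decoder that mirrors every encoder step. The adaptive probability (ones seen so far over bits seen, starting at 1/2) is likewise only an illustrative model.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical coder state; the article keeps the encoder's fields in globals. */
typedef struct {
    uint32_t low, high;    /* current interval, starting at 1 and 0xffffffff */
    uint32_t code;         /* decoder only: 32-bit window into the byte stream */
    unsigned char *bytes;  /* output (encoder) or input (decoder) buffer */
    size_t length;         /* bytes written (encoder) or available (decoder) */
    size_t capacity;       /* size of the buffer */
    size_t position;       /* decoder only: next byte to read */
} Coder;

/* Same split as the article's Encode(): low + floor((high - low) * p / 2^16). */
static uint32_t Split(const Coder *c, float probabilityOfOne)
{
    uint32_t p = (uint32_t)(probabilityOfOne * 65536.0f);
    uint32_t range = c->high - c->low;
    assert(p < 65536u);  /* p == 65536 would leave no room for a 0 bit */
    return c->low + (range >> 16) * p + (((range & 0xffffu) * p) >> 16);
}

static void EncodeBit(Coder *c, int bit, float probabilityOfOne)
{
    uint32_t mid = Split(c, probabilityOfOne);
    if (bit) c->high = mid; else c->low = mid + 1;
    while ((c->high ^ c->low) < 0x1000000u) {  /* leading byte settled: emit it */
        assert(c->length < c->capacity);
        c->bytes[c->length++] = (unsigned char)(c->high >> 24);
        c->high = (c->high << 8) | 255u;
        c->low <<= 8;
    }
}

/* Hypothetical flush: write all four bytes of low, a value inside the final interval. */
static void Flush(Coder *c)
{
    for (int i = 0; i < 4; i++) {
        assert(c->length < c->capacity);
        c->bytes[c->length++] = (unsigned char)(c->low >> 24);
        c->low <<= 8;
    }
}

static uint32_t NextByte(Coder *c)
{
    return c->position < c->length ? c->bytes[c->position++] : 0u;
}

/* Mirror of EncodeBit: the bit is 1 exactly when the code lies in [low, mid]. */
static int DecodeBit(Coder *c, float probabilityOfOne)
{
    uint32_t mid = Split(c, probabilityOfOne);
    int bit = c->code <= mid;
    if (bit) c->high = mid; else c->low = mid + 1;
    while ((c->high ^ c->low) < 0x1000000u) {
        c->high = (c->high << 8) | 255u;
        c->low <<= 8;
        c->code = (c->code << 8) | NextByte(c);
    }
    return bit;
}

/* Encode the 32 bits of number (most significant first) with an adaptive
   estimate of P(bit == 1), decode them again, and return the decoded value. */
uint32_t RoundTrip(uint32_t number, size_t *encodedLength)
{
    unsigned char buffer[4 * 32 + 4];  /* worst case: 4 bytes per bit, plus the flush */
    Coder enc = {1u, 0xffffffffu, 0u, buffer, 0, sizeof buffer, 0};
    int ones = 1, total = 2;           /* start from P(1) = 1/2 */
    for (int i = 31; i >= 0; i--) {
        int bit = (int)((number >> i) & 1u);
        EncodeBit(&enc, bit, (float)ones / (float)total);
        ones += bit;
        total += 1;
    }
    Flush(&enc);

    Coder dec = {1u, 0xffffffffu, 0u, buffer, enc.length, sizeof buffer, 0};
    for (int i = 0; i < 4; i++) dec.code = (dec.code << 8) | NextByte(&dec);
    uint32_t decoded = 0;
    ones = 1;
    total = 2;
    for (int i = 31; i >= 0; i--) {
        int bit = DecodeBit(&dec, (float)ones / (float)total);
        decoded |= (uint32_t)bit << i;
        ones += bit;
        total += 1;
    }
    if (encodedLength != NULL) *encodedLength = enc.length;
    return decoded;
}
```

Any 32-bit pattern should survive the round trip; for example, RoundTrip(0xFFFFFFA4u, &length), the bit pattern of -92, returns 0xFFFFFFA4u and reports how many bytes the encoding used, flush included. Writing all four bytes of the low bound is the simplest safe flush; shorter ones are possible but need more care.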

3. Putting Everything Together: Converting a Vector to FMA’s Complement

Now we work inside main.c. First, we include ArithmeticCoder.h from before and write a function that walks through the bits of a 32-bit integer and counts the number of 1’s.

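#include "ArithmeticCoder.h"
/*Extract individual bits and count number of 1's.
  Note: as listed, oneCount and totalCount are unused and Encode is never called,
  so nothing is written to byteHolder. They look intended to form the
  probability (oneCount / totalCount) passed to Encode for each bit.*/
int IntToBinary(int32_t number, int oneCount, int totalCount)
{
    int bit = 0;
    int internalOneCount = 0;
    for (int i = 31; i >= 0; i--)
    {
        bit = (number >> i) & 1; /*Bit i, most significant first*/
        internalOneCount += bit;
    }
    return internalOneCount;
}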
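IntToBinary shifts the signed input directly. Right-shifting a negative int32_t is implementation-defined in C (GCC and Clang shift arithmetically, which still yields the right bits here). A portable variant, sketched below with the hypothetical name ExtractBits, converts to uint32_t first, which is always well defined, and also records the bits most significant first. Because fixed-point values are stored in two's complement, a small negative value such as -92 carries many leading ones.

```c
#include <assert.h>
#include <stdint.h>

/* Write the 32 bits of number into bits[0..31], most significant first,
   and return how many of them are 1. */
int ExtractBits(int32_t number, unsigned char bits[32])
{
    assert(bits != 0);
    uint32_t pattern = (uint32_t)number;  /* two's complement bit pattern, well defined */
    int oneCount = 0;
    for (int i = 31; i >= 0; i--) {
        unsigned char bit = (unsigned char)((pattern >> i) & 1u);
        bits[31 - i] = bit;
        oneCount += bit;
    }
    return oneCount;
}
```

ExtractBits(-92, bits) returns 27: -92 is 0xFFFFFFA4, which has 24 ones in its three 0xFF bytes plus three more in 0xA4.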

Then we write a function that generates an FMA matrix multiplication lookup table. A struct called IntWithIndex holds the values we produce along the way: the dot product result, the input values, and the encoded bytes. For brevity, I gloss over the IntWithIndex helper functions. Next, I turn the NaiveFixedMatmul function we wrote earlier into a GenerateTable function. For each dot product, we take the binary representation of its input values and the running probability of getting a 1 or a 0.

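Please note that byteHolder must be allocated, and byteHolderLength set, inside our main function before GenerateTable runs: GenerateTable asserts that byteHolder is not NULL, and CreateIntWithIndex uses byteHolderLength to size each entry's copy of the encoded bytes.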

/*IntWithIndex struct: one lookup-table entry per dot product*/
typedef struct integer_index_struct IntWithIndex;
struct integer_index_struct
{
    int integer;               /*Dot product result (raw fixedpt)*/
    int index;                 /*Position before sorting*/
    int *values;               /*Interleaved inputs: mat0[i][k], mat1[k][j], ...*/
    int valuesLength;
    int byteLength0;           /*First encoded byte to print*/
    int byteLength;            /*Number of encoded bytes*/
    unsigned char *byteHolder; /*Copy of the encoded bytes*/
};
/*Comparators for qsort; (x > y) - (x < y) avoids the overflow that x - y can hit*/
int IntWithIndexCompareIndex(const void *a, const void *b)
{
    int x = ((const IntWithIndex *)a)->index, y = ((const IntWithIndex *)b)->index;
    return (x > y) - (x < y);
}
int IntWithIndexCompareInteger(const void *a, const void *b)
{
    int x = ((const IntWithIndex *)a)->integer, y = ((const IntWithIndex *)b)->integer;
    return (x > y) - (x < y);
}
int IntWithIndexCompareIntegerOpposite(const void *a, const void *b)
{
    return IntWithIndexCompareInteger(b, a);
}
int IntWithIndexCompareByteHolder(const void *a, const void *b)
{
    const IntWithIndex *struct_a = (const IntWithIndex *)a;
    const IntWithIndex *struct_b = (const IntWithIndex *)b;
    //if(struct_a->byteLength != struct_b->byteLength){return struct_a->byteLength - struct_b->byteLength;}

    for (int i = 0; i < struct_a->byteLength && i < struct_b->byteLength; ++i)
    {
        if (struct_a->byteHolder[i] != struct_b->byteHolder[i])
        {
            return struct_a->byteHolder[i] - struct_b->byteHolder[i];
        }
    }

    // The shared bytes are equal; lengths are not compared (see the commented-out line above)
    return 0;
}

int IntWithIndexCompareValues(const void *a, const void *b)
{
    const IntWithIndex *struct_a = (const IntWithIndex *)a;
    const IntWithIndex *struct_b = (const IntWithIndex *)b;
    //if(struct_a->byteLength != struct_b->byteLength){return struct_a->byteLength - struct_b->byteLength;}

    for (int i = 0; i < struct_a->valuesLength && i < struct_b->valuesLength; ++i)
    {
        if (struct_a->values[i] != struct_b->values[i])
        {
            return (struct_a->values[i] > struct_b->values[i]) - (struct_a->values[i] < struct_b->values[i]);
        }
    }

    // The shared values are equal
    return 0;
}
void PrintIntWithIndex(int length, IntWithIndex *array)
{
    for (int i = 0; i < length; i++)
    {
        printf("%5d : ", array[i].integer);
        for (int j = 0; j < array[i].valuesLength; j++)
        {
            printf("%4d,", array[i].values[j]);
        }
        printf(" | ");
        for (int j = array[i].byteLength0; j < array[i].byteLength; j++)
        {
            printf("%3u,", array[i].byteHolder[j]);
        }
        printf("\n");
    }
    printf("\n");
}
IntWithIndex *CreateIntWithIndex(int length, int valuesLength)
{
    IntWithIndex *array = malloc(length * sizeof(*array));
    for (int i = 0; i < length; i++)
    {
        array[i].integer = 0;
        array[i].index = 0;
        array[i].valuesLength = valuesLength;
        array[i].values = calloc(valuesLength, sizeof(int));
        array[i].byteLength0 = 0;
        array[i].byteLength = 0;
        /*byteHolderLength is the global from ArithmeticCoder.h: set it before calling this*/
        array[i].byteHolder = calloc(byteHolderLength, sizeof(unsigned char));
    }
    return array;
}

void DestroyIntWithIndex(int length, IntWithIndex *array)
{
    for (int i = 0; i < length; i++)
    {
        free(array[i].values);
        free(array[i].byteHolder);
    }
    free(array);
}

/*Function to Generate Lookup Table*/
void GenerateTable(int row0, int column0, int row1, int column1, fixedpt **mat0, fixedpt **mat1, fixedpt **result)
{
    IntWithIndex *tempIndex = CreateIntWithIndex(row0 * column1, column0 + row1);
    assert(byteHolder != NULL);
    assert(column0 == row1);
    fixedpt temp = 0;
    int currentIndex = 0;
    int valueIndex = 0;
    int oneCount = 1;
    int totalCount = 2;
    for (int i = 0; i < row0; i++)
    {
        for (int j = 0; j < column1; j++)
        {
            result[i][j] = 0;
            valueIndex = 0;
            //printf("%4d : ", currentIndex);
            tempIndex[currentIndex].index = currentIndex;
            /*Same dot product as NaiveFixedMatmul, but also record the inputs, interleaved*/
            for (int k = 0; k < column0; k++)
            {
                //printf("%4d %4d ",mat0[i][k], mat1[k][j]);
                temp = fixedpt_mul(mat0[i][k], mat1[k][j]);
                result[i][j] = fixedpt_add(result[i][j], temp);
                tempIndex[currentIndex].values[valueIndex] = mat0[i][k];
                tempIndex[currentIndex].values[valueIndex + 1] = mat1[k][j];
                valueIndex += 2;
            }
            tempIndex[currentIndex].integer = result[i][j];
            /*Walk the bits of every recorded value, starting the probability counts at 1/2*/
            ResetEncoder();
            oneCount = 1;
            totalCount = 2;
            for (int k = 0; k < valueIndex; k++)
            {
                //printf("%5d ", tempIndex[currentIndex].values[k]);
                oneCount += IntToBinary(tempIndex[currentIndex].values[k], oneCount, totalCount);
                totalCount += 32;
            }
            /*Copy whatever the encoder wrote into this entry*/
            tempIndex[currentIndex].byteLength = byteHolderIndex;
            for (int k = 0; k < byteHolderIndex; k++)
            {
                tempIndex[currentIndex].byteHolder[k] = byteHolder[k];
            }
            //printf(" | %4d\n", result[i][j]);
            currentIndex += 1;
        }
    }
    /*Sort the table by encoded bytes, print it, then free it*/
    qsort(tempIndex, row0 * column1, sizeof(IntWithIndex), IntWithIndexCompareByteHolder);
    PrintIntWithIndex(row0 * column1, tempIndex);
    DestroyIntWithIndex(row0 * column1, tempIndex);
}
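IntWithIndexCompareByteHolder compares bytes only up to the shorter of the two lengths, so two entries where one encoding is a prefix of the other compare as equal, and since qsort is not stable it may place them in either order. If a fully deterministic ordering is wanted, one option is the sketch below: memcmp (which compares bytes as unsigned char) over the shared prefix, then the shorter sequence first. EncodedEntry, CompareEncoded and SortEncoded are hypothetical names; the article's struct is IntWithIndex.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical record: an encoded vector and its length in bytes. */
typedef struct {
    const unsigned char *bytes;
    size_t length;
} EncodedEntry;

/* Lexicographic byte order; a strict prefix sorts before the longer sequence. */
int CompareEncoded(const void *a, const void *b)
{
    const EncodedEntry *x = a;
    const EncodedEntry *y = b;
    size_t shared = x->length < y->length ? x->length : y->length;
    int order = shared ? memcmp(x->bytes, y->bytes, shared) : 0;
    if (order != 0) return order;
    return (x->length > y->length) - (x->length < y->length);
}

void SortEncoded(EncodedEntry *entries, size_t count)
{
    assert(count == 0 || entries != NULL);
    qsort(entries, count, sizeof *entries, CompareEncoded);
}
```

Sorting {9, 209, 101}, {9, 209} and {8, 255, 255} this way yields {8, 255, 255}, then {9, 209}, then {9, 209, 101}.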

Your directory should resemble this GitHub repo and your console should print out these results.

Sample results from running the GenerateTable function

Final: Interpreting the Lookup Table

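Our lookup table prints one line per dot product in the form (dot product result : vectors involved | encoded length in bytes, encoded vector). The vector elements are interleaved as mat0[i][0], mat1[0][j], mat0[i][1], mat1[1][j].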

Let’s analyze the first line of our lookup table.

  • 1 : -92, -49, -30, 149, | 14 9,209,101, 59, 98, 92, 40, 29,184, 36,255,255,255,247,

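This tells us that the dot product of the two vectors (-92, -30) and (-49, 149) equals 1 in fixed-point arithmetic. Every number in the table is a raw fixedpt integer, that is, the real value multiplied by 2⁸ = 256; for example, -92 stands for -0.359375, the first entry in the second row of matrix0 printed after Step 6.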

Dot product of the two vectors (-92, -30) and (-49,149) equals 1
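The arithmetic is easy to check. Assuming fixedptc's default of 8 fractional bits, fixedpt_mul multiplies in 64 bits and shifts right by 8, so NaiveFixedMatmul computes (-92 × -49) >> 8 = 4508 >> 8 = 17 and (-30 × 149) >> 8 = -4470 >> 8 = -18, for a sum of -1, that is -1/256 = -0.00390625. That is the value in row 2, column 4 of the result matrix printed after Step 6, which suggests the minus sign was dropped from the table line quoted above. The sketch below repeats the computation; FixedDot8 is a hypothetical name, and it relies on an arithmetic right shift for negative numbers, as GCC and Clang provide.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Dot product of raw 24.8 fixed-point values with one truncating multiply per
   term, mirroring fixedpt_mul followed by fixedpt_add (assumes 8 fractional bits
   and an arithmetic right shift for negative products). */
int32_t FixedDot8(const int32_t *x, const int32_t *y, size_t n)
{
    assert(n == 0 || (x != NULL && y != NULL));
    int32_t sum = 0;
    for (size_t k = 0; k < n; k++) {
        sum += (int32_t)(((int64_t)x[k] * y[k]) >> 8);
    }
    return sum;
}
```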

We also learn that the vectors (-92, -30) and (-49,149) can be represented as the 14-byte sequence 9,209,101, 59, 98, 92, 40, 29,184, 36,255,255,255,247

In part 2, we analyze the lookup table. There are lots of patterns to observe. You can leave a comment or email me — murage@fileforma.com or DM me on Twitter murage_kibicho

Spoiler alert

Part 2 is coming. In the meantime, these are the patterns to observe.

  1. Stored raw, (-92, -30) and (-49, 149) take 16 bytes (four 32-bit integers). Our encoded representation takes 14 bytes, and we can perform arithmetic on compressed integers.
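  2. We sorted our lookup table by byte value, and the vectors come out grouped: notice that all the (-92, -30) vectors take the first 4 slots and the (327, -117) vectors come next. We do not yet fully understand why; one plausible explanation is that each entry's encoding starts with the bits of mat0[i][0], so entries from the same row of matrix0 share a prefix of encoded bytes and sort next to each other. In part 2 we change the order of our for loops to take advantage of this natural ordering, so that we can cache these results.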
  3. Observe that we can guess (with reasonable accuracy) a dot product that lies between two calculated ones. For instance, if we calculate the two dot products (-92, -30)·(-87, 192) and (-92, -30)·(-322, 315), we can guess the dot product (-92, -30)·(-115, -187). We can form something like a basis set and use it to calculate similar dot products.
We can estimate a vector product if we know the results of the vector product above and below.
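One way to make the basis idea precise: the dot product is linear in each argument, so if a vector v can be written as v = α·u1 + β·u2, then a·v = α(a·u1) + β(a·u2). In two dimensions any two linearly independent vectors form a basis, so two computed products with a determine every other product with a. The sketch below is our illustration of that linearity in exact real arithmetic (doubles, not the article's fixed-point table, where per-term truncation makes the estimate approximate); SolveCoefficients and PredictDot are hypothetical names, and the vectors are the ones quoted above.

```c
#include <assert.h>

/* Find alpha, beta with v = alpha*u1 + beta*u2 for 2-D vectors (Cramer's rule).
   Returns 0 if u1 and u2 are parallel and so do not form a basis. */
int SolveCoefficients(const double u1[2], const double u2[2], const double v[2],
                      double *alpha, double *beta)
{
    assert(alpha != 0 && beta != 0);
    double det = u1[0] * u2[1] - u2[0] * u1[1];
    if (det == 0.0) return 0;
    *alpha = (v[0] * u2[1] - u2[0] * v[1]) / det;
    *beta  = (u1[0] * v[1] - v[0] * u1[1]) / det;
    return 1;
}

/* Predict a.v from the two known products a.u1 and a.u2. */
double PredictDot(double knownDot1, double knownDot2, double alpha, double beta)
{
    return alpha * knownDot1 + beta * knownDot2;
}
```

For the vectors above, (-115, -187) ≈ -2.80·(-87, 192) + 1.11·(-322, 315). In raw integer units (before any shift) the known products with (-92, -30) are 2244 and 20174, and the combination reproduces the direct product 16190. Note that α comes out negative, so this is a linear combination rather than an average of the two known products.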
