We start off by templating our kernel on parameters that are fixed at compile time, and by defining several constants.
template <int InputChannels, int InputLength, int Padding, int KernelSize,
          int ChannelsPerThread>
__global__ void conv1d(float *d_input, float *d_weight, float *d_bias, float *d_output)
{
    constexpr int SharedMemLength = constexpr_max(InputLength, KernelSize);
    const int blockId = blockIdx.x;
    const int tdIdx = threadIdx.x;
    const int laneIdx = threadIdx.x % warpSize;
    const int warpIdx = threadIdx.x / warpSize;
    const int input_accesses_per_thread = (InputChannels * InputLength) / (4 * blockDim.x);
    const int weight_accesses_per_thread = (InputChannels * KernelSize) / blockDim.x;
    const int weight_offset = blockId * InputChannels * KernelSize;
    const int padded_input_length = InputLength + Padding * 2;
    const int shared_mem_offset_denom =
        (InputLength * ChannelsPerThread) < 32 ? 32 : (InputLength * ChannelsPerThread);

    // ... (kernel body continues in the snippets below)
}
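Note that constexpr_max is not a CUDA or standard library function; presumably it is a small helper defined elsewhere in the code. A minimal sketch of what it might look like:

// Hypothetical helper (not shown in the original post): a constexpr max
// that is safe to use in device-side constant expressions.
__host__ __device__ constexpr int constexpr_max(int a, int b)
{
    return a > b ? a : b;
}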
Note the addition of a new template parameter, ChannelsPerThread. We search over different configurations of input tensor shape and channels per thread to find the value that minimizes run time (more on this later; a sketch of the dispatch is shown below). We also define several constants, each of which is described in natural language below.
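Because ChannelsPerThread is a template parameter, each candidate configuration in that search is a distinct kernel instantiation that gets compiled and timed separately. As a rough sketch of what such a dispatch could look like (the shapes and launch parameters here are purely illustrative, not taken from the post):

// Illustrative only: instantiate one kernel per candidate value and
// benchmark each instantiation; the fastest one wins.
template <int ChannelsPerThread>
void launchCandidate(float *d_in, float *d_w, float *d_b, float *d_out)
{
    // Hypothetical shape: 64 channels, length 1024, padding 1, kernel size 3.
    conv1d<64, 1024, 1, 3, ChannelsPerThread>
        <<<64, 256>>>(d_in, d_w, d_b, d_out);
}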
With the constants above defined, we move on to the static allocation of registers and shared memory.
// static memory allocations
float regInput[padded_input_length * ChannelsPerThread] = {0}; // zero-initialized so the padded ends stay zero
float regFilter[KernelSize * ChannelsPerThread];               // per-thread filter weights
__shared__ float shared_mem[InputChannels * SharedMemLength];
Not much is different here from the unoptimized kernel, except that the register allocations are scaled to account for the possibility of multiple channels per thread. There is, however, a minor bug in the shared memory allocation above: the bank-conflict padding applied during the loads below (smem_offset) pushes the highest written index past InputChannels * SharedMemLength, so the array is slightly undersized.
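Assuming that diagnosis is right, one fix is simply to over-allocate by the maximum number of padding slots. Since shared_mem_offset_denom is at least 32, an extra 1/32 of the array is a safe upper bound:

// Sketch of a fix (assumes the overflow diagnosis above): add headroom for
// the padding slots introduced by smem_offset. Because
// shared_mem_offset_denom >= 32, size/32 extra floats always suffice.
__shared__ float shared_mem[InputChannels * SharedMemLength
                            + (InputChannels * SharedMemLength) / 32];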
// Vectorized load of the input from global memory into shared memory.
for (int channelIndex = 0; channelIndex < input_accesses_per_thread; ++channelIndex){
    int td_offset = 4 * (channelIndex * blockDim.x + tdIdx);
    // One padding slot per shared_mem_offset_denom elements staggers the
    // rows across banks, reducing shared memory bank conflicts.
    int smem_offset = td_offset / shared_mem_offset_denom;
    // 128-bit (float4) load: d_input must be 16-byte aligned at td_offset.
    float4 data = *reinterpret_cast<float4*>(&d_input[td_offset]);
    shared_mem[td_offset + smem_offset + 0] = data.x;
    shared_mem[td_offset + smem_offset + 1] = data.y;
    shared_mem[td_offset + smem_offset + 2] = data.z;
    shared_mem[td_offset + smem_offset + 3] = data.w;
}
__syncthreads();
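To make the offset arithmetic concrete, here is how the indices work out in a hypothetical case where shared_mem_offset_denom is 32:

// td_offset =  0  ->  smem_offset = 0  ->  writes indices  0..3
// td_offset = 32  ->  smem_offset = 1  ->  writes indices 33..36
// td_offset = 64  ->  smem_offset = 2  ->  writes indices 66..69
// Each 32-float row shifts by one slot, so accesses that would otherwise
// hit the same bank in every row are spread across different banks.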