We start off by templating our kernel on parameters known at compile time, and defining several constants.

template <int InputChannels, int InputLength, int Padding, int KernelSize,
          int ChannelsPerThread>
__global__ void conv1d(float *d_input, float *d_weight, float *d_bias, float *d_output)
{
    constexpr int SharedMemLength = constexpr_max(InputLength, KernelSize);
    const int blockId = blockIdx.x;
    const int tdIdx = threadIdx.x;
    const int laneIdx = threadIdx.x % warpSize;
    const int warpIdx = threadIdx.x / warpSize;
    // number of vectorized (float4) input loads each thread performs
    const int input_accesses_per_thread = (InputChannels * InputLength) / (4 * blockDim.x);
    // number of weight elements each thread loads
    const int weight_accesses_per_thread = (InputChannels * KernelSize) / blockDim.x;
    // offset into d_weight for this block's filter
    const int weight_offset = blockId * InputChannels * KernelSize;
    const int padded_input_length = InputLength + Padding * 2;
    // elements written to shared memory before one padding slot is skipped (see the load loop below)
    const int shared_mem_offset_denom =
        (InputLength * ChannelsPerThread) < 32 ? 32 : (InputLength * ChannelsPerThread);

    // ...kernel body continues in the snippets below
}
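The SharedMemLength constant relies on a constexpr_max helper that isn't defined in this snippet. A minimal sketch of what such a helper might look like, assuming it only needs to pick the larger of two ints at compile time so the result can size a static array:

// Hypothetical definition of the helper used above: a compile-time max of two ints.
__host__ __device__ constexpr int constexpr_max(int a, int b)
{
    return a > b ? a : b;
}

Marking it __host__ __device__ lets the same helper be evaluated on either side of the compilation without relying on relaxed-constexpr flags.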

Note the addition of a new template parameter, ChannelsPerThread. Since template parameters are fixed at compile time, we search over different configurations of input tensor shape and channels per thread to find a value that minimizes run-time (more on this later); a rough sketch of how such a sweep might look is included after this paragraph. Next, we define several constants, for which I have included a natural language description below.
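To make the search concrete, here is a hypothetical host-side sketch. It times one instantiation of the kernel with made-up tensor shapes (128 input channels, length 1024, padding 1, kernel size 3), assumes d_input, d_weight, d_bias, and d_output are already allocated and initialized, and assumes one block per output channel (which is how weight_offset is indexed above). The actual sweep covers more shapes and channel counts.

#include <cuda_runtime.h>

// Sketch only: times a single template configuration of the kernel.
// threadsPerBlock must be chosen so the per-thread access counts in the kernel
// divide evenly (it likely shrinks as ChannelsPerThread grows).
template <int ChannelsPerThread>
float time_config(float *d_input, float *d_weight, float *d_bias, float *d_output,
                  int outputChannels, int threadsPerBlock)
{
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);
    conv1d<128, 1024, 1, 3, ChannelsPerThread>
        <<<outputChannels, threadsPerBlock>>>(d_input, d_weight, d_bias, d_output);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return ms;
}

// Usage: time candidate values of ChannelsPerThread and keep the fastest, e.g.
//   float t1 = time_config<1>(d_input, d_weight, d_bias, d_output, 128, 128);
//   float t2 = time_config<2>(d_input, d_weight, d_bias, d_output, 128, 64);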

With these defined, we move on to static memory allocations of registers and shared memory.

//static mem allocations
float regInput[padded_input_length * ChannelsPerThread] = {0};
float regFilter[KernelSize * ChannelsPerThread];
// padded with one extra float per shared_mem_offset_denom elements so the
// offset indexing in the load loop below stays within bounds
__shared__ float shared_mem[InputChannels * SharedMemLength
                            + (InputChannels * SharedMemLength) / shared_mem_offset_denom];

Not much is different here from the unoptimized kernel, except that the register allocations are scaled to account for the possibility of having multiple channels per thread, and the shared-memory allocation is padded with one extra float per shared_mem_offset_denom elements so that the offset indexing used in the load loop below cannot write past the end of the array.

// cooperatively stage the input in shared memory using vectorized float4 loads,
// inserting one padding float every shared_mem_offset_denom elements
for (int channelIndex = 0; channelIndex < input_accesses_per_thread; ++channelIndex){
    int td_offset = 4 * (channelIndex * blockDim.x + tdIdx);   // element offset of this thread's float4 load
    int smem_offset = td_offset / shared_mem_offset_denom;     // padding slots accumulated so far
    float4 data = *reinterpret_cast<float4*>(&d_input[td_offset]);
    shared_mem[td_offset + smem_offset + 0] = data.x;
    shared_mem[td_offset + smem_offset + 1] = data.y;
    shared_mem[td_offset + smem_offset + 2] = data.z;
    shared_mem[td_offset + smem_offset + 3] = data.w;
}
__syncthreads();
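The smem_offset term is where shared_mem_offset_denom comes in: after every shared_mem_offset_denom consecutive floats, one shared-memory slot is skipped, so successive rows of the input start at slightly shifted positions. This looks like the standard padding trick for avoiding shared-memory bank conflicts when threads later read whole rows (presumably why the denominator is clamped to at least 32, the number of banks). The toy host-side snippet below, with made-up sizes, just prints where a few row starts land after padding:

#include <cstdio>

// Toy illustration (not part of the kernel): with InputLength = 64 and
// ChannelsPerThread = 1, shared_mem_offset_denom would be 64, so one slot is
// skipped after every 64 floats and each 64-float row starts one bank later.
int main()
{
    const int denom = 64;  // stand-in for shared_mem_offset_denom
    for (int td_offset = 0; td_offset < 256; td_offset += 64) {
        int smem_offset = td_offset / denom;          // padding slots accumulated so far
        int shared_index = td_offset + smem_offset;   // index the kernel would write to
        printf("row start %3d -> shared index %3d (bank %2d)\n",
               td_offset, shared_index, shared_index % 32);
    }
    return 0;
}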