
Questions about the implementation #131

@Glinttsd


Hi, thanks for your great work!

I am trying to understand how the TP (tensor product) kernel is implemented. I noticed that a CUDA file named subkernel_per_interaction_multirep.cuh (attached here) is generated at runtime. I believe this is where the GPU kernel is implemented.

Regarding this CUDA file, I have two questions about its forward pass:

1. The sparse computation that applies the CG coefficients (Algorithm 2 in the paper) does not seem to be unrolled. The source code is provided below; the for loop is not tagged with `#pragma unroll`.

   ```
   // PERFORM CG DECOMPOSITION CALCULATION
   {%- for i in range(tensor.nnz) %}
       {%- set coord1, coord2, coord3, value = tensor.tuples[i] %}
       L3_local_vec[{{coord3}}] += {{value}} * {{instructions[interaction_index].path_weight}} * L1_local_vec[{{coord1}}] * L2_local_vec[{{coord2}}];
   {%- endfor %}
   ```
2. The results in `L3_local_vec` computed above seem to be written to shared memory (`smem_gemm_L3`) first and only then multiplied by the weight matrix. This seems to contradict the paper's description:

> Finally, the output Z is accumulated to shared memory after multiplication by W.

The source code is provided below.

```
// WRITE TO SMEM_GEMM_L3
#pragma unroll
for(int L3_irrep_index = 0; L3_irrep_index < {{L3.irrep_lengths[w]}}; L3_irrep_index++){
    smem_gemm_L3[(threadIdx.x * {{L3.irrep_lengths[w]}}) + L3_irrep_index] = L3_local_vec[L3_irrep_index];
}

// WAIT FOR WEIGHTS TO HAVE ARRIVED
cooperative_groups::wait(group);

group.sync();

// PERFORM MATMUL
int i = threadIdx.x;
if(i < L3_mults){
    (some calculations)
}
group.sync();
```
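On the first question, it may help to note that `{%- for %}` ... `{%- endfor %}` is Jinja template syntax. Assuming the kernel source is rendered by Jinja, as that syntax suggests, the loop runs when the template is rendered, not when the kernel executes. The C++ sketch below illustrates the difference for a made-up sparse tensor with three nonzeros: `render_unrolled` mimics the template by emitting one straight-line statement per `(coord1, coord2, coord3, value)` tuple, `cg_unrolled_example` shows that output as code, and `cg_runtime_loop` is the equivalent runtime loop. `CGEntry`, `example_tuples`, and all three functions are hypothetical names for illustration, not part of the project.

```cpp
#include <cassert>
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical stand-in for one entry of tensor.tuples: (coord1, coord2, coord3, value).
struct CGEntry {
    int coord1, coord2, coord3;
    double value;
};

// Three made-up nonzeros, used only for illustration.
std::vector<CGEntry> example_tuples() {
    return {{0, 0, 0, 1.0}, {1, 1, 0, 0.5}, {0, 1, 1, -2.0}};
}

// Mimics the Jinja {%- for %} block: the loop runs at generation time and emits
// one straight-line statement per nonzero, with indices and coefficients as literals.
std::string render_unrolled(const std::vector<CGEntry>& tuples, double path_weight) {
    std::string src;
    char line[256];
    for (const CGEntry& t : tuples) {
        std::snprintf(line, sizeof(line),
                      "L3_local_vec[%d] += %g * %g * L1_local_vec[%d] * L2_local_vec[%d];\n",
                      t.coord3, t.value, path_weight, t.coord1, t.coord2);
        src += line;
    }
    return src;
}

// What render_unrolled(example_tuples(), 1.5) looks like once compiled as code:
// no loop counter and no index loads, so there is no loop for #pragma unroll to act on.
void cg_unrolled_example(const double* L1_local_vec, const double* L2_local_vec,
                         double* L3_local_vec) {
    L3_local_vec[0] += 1 * 1.5 * L1_local_vec[0] * L2_local_vec[0];
    L3_local_vec[0] += 0.5 * 1.5 * L1_local_vec[1] * L2_local_vec[1];
    L3_local_vec[1] += -2 * 1.5 * L1_local_vec[0] * L2_local_vec[1];
}

// The same contraction as a runtime loop: indices and coefficients are read from
// memory on every iteration. A loop like this is what #pragma unroll would act on.
void cg_runtime_loop(const std::vector<CGEntry>& tuples, double path_weight,
                     const double* L1, const double* L2, double* L3) {
    for (const CGEntry& t : tuples)
        L3[t.coord3] += t.value * path_weight * L1[t.coord1] * L2[t.coord2];
}
```

With these tuples and a path weight of 1.5, `render_unrolled` returns three lines, the first being `L3_local_vec[0] += 1 * 1.5 * L1_local_vec[0] * L2_local_vec[0];`, and both evaluation paths produce identical `L3` values.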

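On the second question, it may help to distinguish shared memory used as a staging buffer for the GEMM input from shared memory that receives the accumulated output. The CPU sketch below models one possible reading of the snippet, under assumed layouts: each "thread" `j` stages its `L3_local_vec` as row `j` of `smem_gemm_L3`, a barrier (`group.sync()` in the kernel) makes all rows visible, each output multiplicity `i` then forms `sum_j W[j][i] * Z[j][k]`, and only that product is added into an output buffer. The row-major `W` layout, the `out_smem` buffer, and the function names are hypothetical; the real work happens inside the elided `(some calculations)` block, which this sketch does not reproduce.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Step 1: every "thread" j copies its private L3_local_vec into row j of the
// staging buffer, mirroring
//   smem_gemm_L3[(threadIdx.x * irrep_len) + k] = L3_local_vec[k];
std::vector<double> stage_rows(const std::vector<std::vector<double>>& local_vecs,
                               int irrep_len) {
    std::vector<double> smem_gemm_L3(local_vecs.size() * irrep_len, 0.0);
    for (std::size_t j = 0; j < local_vecs.size(); ++j)
        for (int k = 0; k < irrep_len; ++k)
            smem_gemm_L3[j * irrep_len + k] = local_vecs[j][k];
    return smem_gemm_L3;
}

// Step 2: on the GPU a barrier is required before this step, because output
// row i reads rows written by other threads. Output multiplicity i computes
//   Y[i][k] = sum_j W[j][i] * Z[j][k]
// with W assumed row-major of shape in_mults x out_mults (hypothetical layout).
std::vector<double> gemm_after_staging(const std::vector<double>& smem_gemm_L3,
                                       const std::vector<double>& W,
                                       int in_mults, int out_mults, int irrep_len) {
    std::vector<double> Y(static_cast<std::size_t>(out_mults) * irrep_len, 0.0);
    for (int i = 0; i < out_mults; ++i)  // one "thread" per output multiplicity
        for (int k = 0; k < irrep_len; ++k) {
            double acc = 0.0;
            for (int j = 0; j < in_mults; ++j)
                acc += W[j * out_mults + i] * smem_gemm_L3[j * irrep_len + k];
            Y[i * irrep_len + k] = acc;
        }
    return Y;
}

// Step 3 (hypothetical): only the weighted result is accumulated into the output buffer.
void accumulate_into_output(std::vector<double>& out_smem, const std::vector<double>& Y) {
    for (std::size_t idx = 0; idx < Y.size(); ++idx)
        out_smem[idx] += Y[idx];
}
```

For example, with rows `{1, 2, 3}` and `{4, 5, 6}`, `W = {1, 0, 2, 1}` (row-major, 2 x 2), and an output buffer initialized to ones, the final buffer is `{10, 13, 16, 5, 6, 7}`. Staging is what lets output row `i` read rows produced by other threads; under this reading, Z passing through `smem_gemm_L3` before the multiply is not necessarily at odds with accumulating only after multiplication by W.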
Any help would be greatly appreciated! Thank you again for the great work!
