@@ -178,10 +178,16 @@ def _check_type(
178178 mCuSeqlensK_type : Type [cutlass .Numeric ] | None ,
179179 mSeqUsedQ_type : Type [cutlass .Numeric ] | None ,
180180 mSeqUsedK_type : Type [cutlass .Numeric ] | None ,
181+ is_split_kv : bool = False ,
181182 ):
182- # Get the data type and check if it is fp16 or bf16
183- if const_expr (not (mQ_type == mK_type == mV_type == mO_type )):
184- raise TypeError ("All tensors must have the same data type" )
183+ if is_split_kv :
184+ if const_expr (not (mQ_type == mK_type == mV_type )):
185+ raise TypeError ("Q, K, V tensors must have the same data type" )
186+ if const_expr (mO_type != Float32 ):
187+ raise TypeError ("O tensor must be Float32 for split_kv" )
188+ else :
189+ if const_expr (not (mQ_type == mK_type == mV_type == mO_type )):
190+ raise TypeError ("All tensors must have the same data type" )
185191 if const_expr (mQ_type not in [cutlass .Float16 , cutlass .BFloat16 ]):
186192 raise TypeError ("Only Float16 or BFloat16 is supported" )
187193 if const_expr (mLSE_type not in [None , Float32 ]):
@@ -336,30 +342,33 @@ def epilogue(
336342 m_block : Int32 ,
337343 head_idx : Int32 ,
338344 batch_idx : Int32 ,
345+ split_idx : Int32 = Int32 (0 ),
339346 ):
340- # store acc_O
341- rO = cute .make_fragment_like (acc_O , self .dtype )
342- rO .store (acc_O .load ().to (self .dtype ))
343- # Make sure all threads have finished reading V
344- cute .arch .barrier (
345- barrier_id = int (NamedBarrierFwd .Epilogue ), number_of_threads = self .num_epilogue_threads
346- )
347- smem_copy_atom_O = utils .get_smem_store_atom (self .arch .major * 10 + self .arch .minor , self .dtype )
348- smem_thr_copy_O = cute .make_tiled_copy_C (smem_copy_atom_O , tiled_mma ).get_slice (tidx )
349- taccOrO = smem_thr_copy_O .retile (rO )
350- taccOsO = smem_thr_copy_O .partition_D (sO )
351- # taccOsO = copy_utils.partition_D_position_independent(smem_thr_copy_O, sO)
352- # copy acc O from rmem to smem with the smem copy atom
353- cute .copy (smem_copy_atom_O , taccOrO , taccOsO )
354-
355347 cO = cute .make_identity_tensor ((self .tile_m , self .tile_hdimv ))
356348 pack_gqa = PackGQA (
357349 self .tile_m , self .tile_hdimv , self .check_hdim_v_oob , self .qhead_per_kvhead
358350 )
359351
352+ if const_expr (not self .is_split_kv ):
353+ rO = cute .make_fragment_like (acc_O , self .dtype )
354+ rO .store (acc_O .load ().to (self .dtype ))
355+ # Make sure all threads have finished reading V
356+ cute .arch .barrier (
357+ barrier_id = int (NamedBarrierFwd .Epilogue ), number_of_threads = self .num_epilogue_threads
358+ )
359+ smem_copy_atom_O = utils .get_smem_store_atom (self .arch .major * 10 + self .arch .minor , self .dtype )
360+ smem_thr_copy_O = cute .make_tiled_copy_C (smem_copy_atom_O , tiled_mma ).get_slice (tidx )
361+ taccOrO = smem_thr_copy_O .retile (rO )
362+ taccOsO = smem_thr_copy_O .partition_D (sO )
363+ # copy acc O from rmem to smem with the smem copy atom
364+ cute .copy (smem_copy_atom_O , taccOrO , taccOsO )
365+
360366 # Write LSE from rmem -> gmem
361367 if const_expr (mLSE is not None ):
362- mLSE_cur = seqlen .offset_batch_Q (mLSE , batch_idx , dim = 2 )[None , head_idx ]
368+ if const_expr (self .is_split_kv ):
369+ mLSE_cur = seqlen .offset_batch_Q (mLSE , batch_idx , dim = 2 )[None , head_idx , split_idx ]
370+ else :
371+ mLSE_cur = seqlen .offset_batch_Q (mLSE , batch_idx , dim = 2 )[None , head_idx ]
363372 if const_expr (not self .pack_gqa ):
364373 gLSE = cute .local_tile (mLSE_cur , (self .tile_m ,), (m_block ,))
365374 gLSE_expanded_layout = cute .append (
@@ -383,63 +392,88 @@ def epilogue(
383392 pack_gqa .store_LSE (mLSE_cur , lse , tiled_mma , tidx , m_block , seqlen .seqlen_q )
384393
385394 ragged = self .use_tma_O and (seqlen .has_cu_seqlens_q or seqlen .has_seqused_q )
386- mO_cur = seqlen .offset_batch_Q (mO , batch_idx , dim = 3 , ragged = ragged )[None , None , head_idx ]
387- # thr_mma = tiled_mma.get_slice(tidx)
388- # taccOgO = thr_mma.partition_C(gO)
389- # cute.autovec_copy(rO, taccOgO)
390- # sync to make sure all smem stores are done
391- if const_expr (self .use_tma_O ):
392- # ensure smem writes are visible to TMA
393- cute .arch .fence_view_async_shared ()
394- cute .arch .barrier_arrive (
395- barrier_id = int (NamedBarrierFwd .Epilogue ),
396- number_of_threads = self .num_epilogue_threads + cute .arch .WARP_SIZE ,
397- )
398- gO = cute .local_tile (mO_cur , (self .tile_m , self .tile_hdimv ), (m_block , 0 ))
399- store_O , _ , _ = copy_utils .tma_get_copy_fn (
400- tma_atom_O , 0 , cute .make_layout (1 ), sO , gO , single_stage = True
401- )
402- warp_idx = cute .arch .make_warp_uniform (cute .arch .warp_idx ())
403- if warp_idx == 4 :
404- cute .arch .barrier (
405- barrier_id = int (NamedBarrierFwd .Epilogue ),
406- number_of_threads = self .num_epilogue_threads + cute .arch .WARP_SIZE ,
407- )
408- store_O ()
409- cute .arch .cp_async_bulk_commit_group ()
410- cute .arch .cp_async_bulk_wait_group (0 , read = True )
395+ if const_expr (self .is_split_kv ):
396+ mO_cur = seqlen .offset_batch_Q (mO , batch_idx , dim = 3 )[None , None , head_idx , split_idx ]
411397 else :
398+ mO_cur = seqlen .offset_batch_Q (mO , batch_idx , dim = 3 , ragged = ragged )[None , None , head_idx ]
399+
400+ if const_expr (self .is_split_kv ):
412401 cute .arch .barrier (
413- barrier_id = int (NamedBarrierFwd .Epilogue ),
414- number_of_threads = self .num_epilogue_threads ,
402+ barrier_id = int (NamedBarrierFwd .Epilogue ), number_of_threads = self .num_epilogue_threads
415403 )
416- gmem_thr_copy_O = gmem_tiled_copy_O .get_slice (tidx )
417- tOsO = gmem_thr_copy_O .partition_S (sO )
418- tOrO = cute .make_fragment_like (tOsO , self .dtype )
419- # load acc O from smem to rmem for wider vectorization
420- cute .autovec_copy (tOsO , tOrO )
421404 if const_expr (not self .pack_gqa ):
422405 gO = cute .local_tile (mO_cur , (self .tile_m , self .tile_hdimv ), (m_block , 0 ))
423- tOgO = gmem_thr_copy_O .partition_D (gO )
424- tOcO = gmem_thr_copy_O .partition_S (cO )
425- t0OcO = gmem_tiled_copy_O .get_slice (0 ).partition_S (cO )
426- tOpO = utils .predicate_k (tOcO , limit = mO .shape [1 ])
427- # copy acc O from rmem to gmem
428- for rest_m in cutlass .range_constexpr (cute .size (tOrO .shape [1 ])):
429- if (
430- t0OcO [0 , rest_m , 0 ][0 ]
431- < seqlen .seqlen_q - m_block * self .tile_m - tOcO [0 ][0 ]
432- ):
433- cute .copy (
434- gmem_tiled_copy_O ,
435- tOrO [None , rest_m , None ],
436- tOgO [None , rest_m , None ],
437- pred = tOpO [None , rest_m , None ]
438- if const_expr (self .check_hdim_v_oob )
439- else None ,
440- )
406+ thr_mma = tiled_mma .get_slice (tidx )
407+ taccOgO = layout_utils .reshape_acc_to_mn (thr_mma .partition_C (gO ))
408+ taccOcO = layout_utils .reshape_acc_to_mn (thr_mma .partition_C (cO ))
409+ taccOrO = layout_utils .reshape_acc_to_mn (acc_O )
410+ seqlen_q_limit = seqlen .seqlen_q - m_block * self .tile_m
411+ for k in cutlass .range_constexpr (cute .size (taccOrO .shape [0 ])):
412+ if taccOcO [k , 0 ][0 ] < seqlen_q_limit :
413+ for m in cutlass .range_constexpr (cute .size (taccOrO .shape [1 ])):
414+ if const_expr (not self .check_hdim_v_oob ) or taccOcO [k , m ][1 ] < mO .shape [1 ]:
415+ taccOgO [k , m ] = taccOrO [k , m ]
441416 else :
442- pack_gqa .store_O (mO_cur , tOrO , gmem_tiled_copy_O , tidx , m_block , seqlen .seqlen_q )
417+ # mO_gqa is ((qheads_per_kvhead, seqlen_q), d, h_kv)
418+ if const_expr (not seqlen .has_cu_seqlens_q ):
419+ mO_gqa = mO [None , None , None , batch_idx , split_idx ]
420+ else :
421+ offset = (0 , seqlen .offset_q )
422+ mO_gqa = cute .domain_offset ((offset , 0 , 0 ), mO [None , None , None , split_idx ])
423+ pack_gqa .store_O_splitkv (mO_gqa , acc_O , tiled_mma , tidx , m_block , seqlen .seqlen_q , head_idx )
424+ else :
425+ if const_expr (self .use_tma_O ):
426+ # ensure smem writes are visible to TMA
427+ cute .arch .fence_view_async_shared ()
428+ cute .arch .barrier_arrive (
429+ barrier_id = int (NamedBarrierFwd .Epilogue ),
430+ number_of_threads = self .num_epilogue_threads + cute .arch .WARP_SIZE ,
431+ )
432+ gO = cute .local_tile (mO_cur , (self .tile_m , self .tile_hdimv ), (m_block , 0 ))
433+ store_O , _ , _ = copy_utils .tma_get_copy_fn (
434+ tma_atom_O , 0 , cute .make_layout (1 ), sO , gO , single_stage = True
435+ )
436+ warp_idx = cute .arch .make_warp_uniform (cute .arch .warp_idx ())
437+ if warp_idx == 4 :
438+ cute .arch .barrier (
439+ barrier_id = int (NamedBarrierFwd .Epilogue ),
440+ number_of_threads = self .num_epilogue_threads + cute .arch .WARP_SIZE ,
441+ )
442+ store_O ()
443+ cute .arch .cp_async_bulk_commit_group ()
444+ cute .arch .cp_async_bulk_wait_group (0 , read = True )
445+ else :
446+ cute .arch .barrier (
447+ barrier_id = int (NamedBarrierFwd .Epilogue ),
448+ number_of_threads = self .num_epilogue_threads ,
449+ )
450+ gmem_thr_copy_O = gmem_tiled_copy_O .get_slice (tidx )
451+ tOsO = gmem_thr_copy_O .partition_S (sO )
452+ tOrO = cute .make_fragment_like (tOsO , self .dtype )
453+ # load acc O from smem to rmem for wider vectorization
454+ cute .autovec_copy (tOsO , tOrO )
455+ if const_expr (not self .pack_gqa ):
456+ gO = cute .local_tile (mO_cur , (self .tile_m , self .tile_hdimv ), (m_block , 0 ))
457+ tOgO = gmem_thr_copy_O .partition_D (gO )
458+ tOcO = gmem_thr_copy_O .partition_S (cO )
459+ t0OcO = gmem_tiled_copy_O .get_slice (0 ).partition_S (cO )
460+ tOpO = utils .predicate_k (tOcO , limit = mO .shape [1 ])
461+ # copy acc O from rmem to gmem
462+ for rest_m in cutlass .range_constexpr (cute .size (tOrO .shape [1 ])):
463+ if (
464+ t0OcO [0 , rest_m , 0 ][0 ]
465+ < seqlen .seqlen_q - m_block * self .tile_m - tOcO [0 ][0 ]
466+ ):
467+ cute .copy (
468+ gmem_tiled_copy_O ,
469+ tOrO [None , rest_m , None ],
470+ tOgO [None , rest_m , None ],
471+ pred = tOpO [None , rest_m , None ]
472+ if const_expr (self .check_hdim_v_oob )
473+ else None ,
474+ )
475+ else :
476+ pack_gqa .store_O (mO_cur , tOrO , gmem_tiled_copy_O , tidx , m_block , seqlen .seqlen_q )
443477
444478 @cute .jit
445479 def advance_pipeline (self , pipeline_index ):