Skip to content

Commit f5bc33c

Browse files
Merge pull request #130 from IwakuraRein/enable-hdim-512-sm90-vllm
Enable hdim 512 sm90 vllm
2 parents c0ec424 + 54eaa3b commit f5bc33c

7 files changed

Lines changed: 211 additions & 104 deletions

File tree

flash_attn/cute/flash_fwd.py

Lines changed: 104 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,16 @@ def _check_type(
178178
mCuSeqlensK_type: Type[cutlass.Numeric] | None,
179179
mSeqUsedQ_type: Type[cutlass.Numeric] | None,
180180
mSeqUsedK_type: Type[cutlass.Numeric] | None,
181+
is_split_kv: bool = False,
181182
):
182-
# Get the data type and check if it is fp16 or bf16
183-
if const_expr(not (mQ_type == mK_type == mV_type == mO_type)):
184-
raise TypeError("All tensors must have the same data type")
183+
if is_split_kv:
184+
if const_expr(not (mQ_type == mK_type == mV_type)):
185+
raise TypeError("Q, K, V tensors must have the same data type")
186+
if const_expr(mO_type != Float32):
187+
raise TypeError("O tensor must be Float32 for split_kv")
188+
else:
189+
if const_expr(not (mQ_type == mK_type == mV_type == mO_type)):
190+
raise TypeError("All tensors must have the same data type")
185191
if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]):
186192
raise TypeError("Only Float16 or BFloat16 is supported")
187193
if const_expr(mLSE_type not in [None, Float32]):
@@ -336,30 +342,33 @@ def epilogue(
336342
m_block: Int32,
337343
head_idx: Int32,
338344
batch_idx: Int32,
345+
split_idx: Int32 = Int32(0),
339346
):
340-
# store acc_O
341-
rO = cute.make_fragment_like(acc_O, self.dtype)
342-
rO.store(acc_O.load().to(self.dtype))
343-
# Make sure all threads have finished reading V
344-
cute.arch.barrier(
345-
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
346-
)
347-
smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype)
348-
smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
349-
taccOrO = smem_thr_copy_O.retile(rO)
350-
taccOsO = smem_thr_copy_O.partition_D(sO)
351-
# taccOsO = copy_utils.partition_D_position_independent(smem_thr_copy_O, sO)
352-
# copy acc O from rmem to smem with the smem copy atom
353-
cute.copy(smem_copy_atom_O, taccOrO, taccOsO)
354-
355347
cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv))
356348
pack_gqa = PackGQA(
357349
self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead
358350
)
359351

352+
if const_expr(not self.is_split_kv):
353+
rO = cute.make_fragment_like(acc_O, self.dtype)
354+
rO.store(acc_O.load().to(self.dtype))
355+
# Make sure all threads have finished reading V
356+
cute.arch.barrier(
357+
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
358+
)
359+
smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype)
360+
smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
361+
taccOrO = smem_thr_copy_O.retile(rO)
362+
taccOsO = smem_thr_copy_O.partition_D(sO)
363+
# copy acc O from rmem to smem with the smem copy atom
364+
cute.copy(smem_copy_atom_O, taccOrO, taccOsO)
365+
360366
# Write LSE from rmem -> gmem
361367
if const_expr(mLSE is not None):
362-
mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx]
368+
if const_expr(self.is_split_kv):
369+
mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx, split_idx]
370+
else:
371+
mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx]
363372
if const_expr(not self.pack_gqa):
364373
gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,))
365374
gLSE_expanded_layout = cute.append(
@@ -383,63 +392,88 @@ def epilogue(
383392
pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q)
384393

385394
ragged = self.use_tma_O and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
386-
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx]
387-
# thr_mma = tiled_mma.get_slice(tidx)
388-
# taccOgO = thr_mma.partition_C(gO)
389-
# cute.autovec_copy(rO, taccOgO)
390-
# sync to make sure all smem stores are done
391-
if const_expr(self.use_tma_O):
392-
# ensure smem writes are visible to TMA
393-
cute.arch.fence_view_async_shared()
394-
cute.arch.barrier_arrive(
395-
barrier_id=int(NamedBarrierFwd.Epilogue),
396-
number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
397-
)
398-
gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
399-
store_O, _, _ = copy_utils.tma_get_copy_fn(
400-
tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True
401-
)
402-
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
403-
if warp_idx == 4:
404-
cute.arch.barrier(
405-
barrier_id=int(NamedBarrierFwd.Epilogue),
406-
number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
407-
)
408-
store_O()
409-
cute.arch.cp_async_bulk_commit_group()
410-
cute.arch.cp_async_bulk_wait_group(0, read=True)
395+
if const_expr(self.is_split_kv):
396+
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
411397
else:
398+
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx]
399+
400+
if const_expr(self.is_split_kv):
412401
cute.arch.barrier(
413-
barrier_id=int(NamedBarrierFwd.Epilogue),
414-
number_of_threads=self.num_epilogue_threads,
402+
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
415403
)
416-
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
417-
tOsO = gmem_thr_copy_O.partition_S(sO)
418-
tOrO = cute.make_fragment_like(tOsO, self.dtype)
419-
# load acc O from smem to rmem for wider vectorization
420-
cute.autovec_copy(tOsO, tOrO)
421404
if const_expr(not self.pack_gqa):
422405
gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
423-
tOgO = gmem_thr_copy_O.partition_D(gO)
424-
tOcO = gmem_thr_copy_O.partition_S(cO)
425-
t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
426-
tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
427-
# copy acc O from rmem to gmem
428-
for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
429-
if (
430-
t0OcO[0, rest_m, 0][0]
431-
< seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0]
432-
):
433-
cute.copy(
434-
gmem_tiled_copy_O,
435-
tOrO[None, rest_m, None],
436-
tOgO[None, rest_m, None],
437-
pred=tOpO[None, rest_m, None]
438-
if const_expr(self.check_hdim_v_oob)
439-
else None,
440-
)
406+
thr_mma = tiled_mma.get_slice(tidx)
407+
taccOgO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(gO))
408+
taccOcO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cO))
409+
taccOrO = layout_utils.reshape_acc_to_mn(acc_O)
410+
seqlen_q_limit = seqlen.seqlen_q - m_block * self.tile_m
411+
for k in cutlass.range_constexpr(cute.size(taccOrO.shape[0])):
412+
if taccOcO[k, 0][0] < seqlen_q_limit:
413+
for m in cutlass.range_constexpr(cute.size(taccOrO.shape[1])):
414+
if const_expr(not self.check_hdim_v_oob) or taccOcO[k, m][1] < mO.shape[1]:
415+
taccOgO[k, m] = taccOrO[k, m]
441416
else:
442-
pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q)
417+
# mO_gqa is ((qheads_per_kvhead, seqlen_q), d, h_kv)
418+
if const_expr(not seqlen.has_cu_seqlens_q):
419+
mO_gqa = mO[None, None, None, batch_idx, split_idx]
420+
else:
421+
offset = (0, seqlen.offset_q)
422+
mO_gqa = cute.domain_offset((offset, 0, 0), mO[None, None, None, split_idx])
423+
pack_gqa.store_O_splitkv(mO_gqa, acc_O, tiled_mma, tidx, m_block, seqlen.seqlen_q, head_idx)
424+
else:
425+
if const_expr(self.use_tma_O):
426+
# ensure smem writes are visible to TMA
427+
cute.arch.fence_view_async_shared()
428+
cute.arch.barrier_arrive(
429+
barrier_id=int(NamedBarrierFwd.Epilogue),
430+
number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
431+
)
432+
gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
433+
store_O, _, _ = copy_utils.tma_get_copy_fn(
434+
tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True
435+
)
436+
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
437+
if warp_idx == 4:
438+
cute.arch.barrier(
439+
barrier_id=int(NamedBarrierFwd.Epilogue),
440+
number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
441+
)
442+
store_O()
443+
cute.arch.cp_async_bulk_commit_group()
444+
cute.arch.cp_async_bulk_wait_group(0, read=True)
445+
else:
446+
cute.arch.barrier(
447+
barrier_id=int(NamedBarrierFwd.Epilogue),
448+
number_of_threads=self.num_epilogue_threads,
449+
)
450+
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
451+
tOsO = gmem_thr_copy_O.partition_S(sO)
452+
tOrO = cute.make_fragment_like(tOsO, self.dtype)
453+
# load acc O from smem to rmem for wider vectorization
454+
cute.autovec_copy(tOsO, tOrO)
455+
if const_expr(not self.pack_gqa):
456+
gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
457+
tOgO = gmem_thr_copy_O.partition_D(gO)
458+
tOcO = gmem_thr_copy_O.partition_S(cO)
459+
t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
460+
tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
461+
# copy acc O from rmem to gmem
462+
for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
463+
if (
464+
t0OcO[0, rest_m, 0][0]
465+
< seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0]
466+
):
467+
cute.copy(
468+
gmem_tiled_copy_O,
469+
tOrO[None, rest_m, None],
470+
tOgO[None, rest_m, None],
471+
pred=tOpO[None, rest_m, None]
472+
if const_expr(self.check_hdim_v_oob)
473+
else None,
474+
)
475+
else:
476+
pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q)
443477

444478
@cute.jit
445479
def advance_pipeline(self, pipeline_index):

0 commit comments

Comments
 (0)