Skip to content

Commit 08185b9

Browse files
authored
Update blackwell tutorial to be compatible with 4.5-dev version (#3130)
* Update blackwell tutorial to be compatible with 4.5-dev version * update example for reverted changes * add more example fix
1 parent bd01dd3 commit 08185b9

12 files changed

Lines changed: 29 additions & 29 deletions

examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent_prefetch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ class SharedStorage:
647647
ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
648648
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
649649
acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
650-
tmem_dealloc_mbar_ptr: cutlass.Int64
650+
tmem_dealloc_mbar: cutlass.Int64
651651
tmem_holding_buf: cutlass.Int32
652652
# (EPI_TILE_M, EPI_TILE_N, STAGE)
653653
sC: cute.struct.Align[
@@ -826,11 +826,11 @@ def kernel(
826826

827827
# Tensor memory dealloc barrier init
828828
tmem = utils.TmemAllocator(
829-
storage.tmem_holding_buf,
829+
storage.tmem_holding_buf.ptr,
830830
barrier_for_retrieve=self.tmem_alloc_barrier,
831831
allocator_warp_id=self.epilog_warp_id[0],
832832
is_two_cta=use_2cta_instrs,
833-
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
833+
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
834834
)
835835

836836
# Cluster arrive after barrier init

examples/python/CuTeDSL/blackwell/dense_gemm_persistent_prefetch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ class SharedStorage:
648648
acc_full_mbar_ptr: cute.struct.MemRange[
649649
cutlass.Int64, self.num_acc_stage * 2
650650
]
651-
tmem_dealloc_mbar_ptr: cutlass.Int64
651+
tmem_dealloc_mbar: cutlass.Int64
652652
tmem_holding_buf: cutlass.Int32
653653

654654
smem = utils.SmemAllocator()
@@ -699,11 +699,11 @@ class SharedStorage:
699699
)
700700
# Tensor memory dealloc barrier init
701701
tmem = utils.TmemAllocator(
702-
storage.tmem_holding_buf,
702+
storage.tmem_holding_buf.ptr,
703703
barrier_for_retrieve=tmem_alloc_barrier,
704704
allocator_warp_id=self.epilog_warp_id[0],
705705
is_two_cta=use_2cta_instrs,
706-
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
706+
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
707707
)
708708

709709
# Cluster arrive after barrier init

examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,11 @@ def kernel(
219219
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
220220
)
221221
tmem = utils.TmemAllocator(
222-
storage.tmem_holding_buffer,
222+
storage.tmem_holding_buffer.ptr,
223223
barrier_for_retrieve=tmem_alloc_barrier,
224224
allocator_warp_id=epilogue_warp_ids[0],
225225
is_two_cta=True if use_2cta_instrs else False,
226-
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
226+
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
227227
)
228228

229229
# Partition tensors for TMA; This requires the tensors partitioned for MMA

examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,11 @@ def kernel(
152152
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
153153
)
154154
tmem = utils.TmemAllocator(
155-
storage.tmem_holding_buffer,
155+
storage.tmem_holding_buffer.ptr,
156156
barrier_for_retrieve=tmem_alloc_barrier,
157157
allocator_warp_id=epilogue_warp_ids[0],
158158
is_two_cta=True,
159-
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
159+
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
160160
)
161161

162162
num_tma_copy_bytes = (

examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3_1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,11 @@ def kernel(
159159
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
160160
)
161161
tmem = utils.TmemAllocator(
162-
storage.tmem_holding_buffer,
162+
storage.tmem_holding_buffer.ptr,
163163
barrier_for_retrieve=tmem_alloc_barrier,
164164
allocator_warp_id=epilogue_warp_ids[0],
165165
is_two_cta=True,
166-
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
166+
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
167167
)
168168

169169
num_tma_copy_bytes = (

examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,11 @@ def cluster_specific_kernel(
184184
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
185185
)
186186
tmem = utils.TmemAllocator(
187-
storage.tmem_holding_buffer,
187+
storage.tmem_holding_buffer.ptr,
188188
barrier_for_retrieve=tmem_alloc_barrier,
189189
allocator_warp_id=epilogue_warp_ids[0],
190190
is_two_cta=True,
191-
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
191+
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
192192
)
193193

194194
num_tma_copy_bytes = (

examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_5.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,11 @@ def kernel(
171171
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
172172
)
173173
tmem = utils.TmemAllocator(
174-
storage.tmem_holding_buffer,
174+
storage.tmem_holding_buffer.ptr,
175175
barrier_for_retrieve=tmem_alloc_barrier,
176176
allocator_warp_id=epilogue_warp_ids[0],
177177
is_two_cta=True,
178-
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
178+
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
179179
)
180180

181181
num_tma_copy_bytes = (

examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_6.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,11 @@ def gemm(
214214
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
215215
)
216216
tmem = utils.TmemAllocator(
217-
storage.tmem_holding_buffer,
217+
storage.tmem_holding_buffer.ptr,
218218
barrier_for_retrieve=tmem_alloc_barrier,
219219
allocator_warp_id=epilogue_warp_ids[0],
220220
is_two_cta=True,
221-
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
221+
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
222222
)
223223

224224
num_tma_copy_bytes = (

examples/python/CuTeDSL/distributed/distributed_all_gather_gemm_blackwell.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,7 @@ class SharedStorage:
756756
acc_full_mbar_ptr: cute.struct.MemRange[
757757
cutlass.Int64, self.num_acc_stage * 2
758758
]
759-
tmem_dealloc_mbar_ptr: cutlass.Int64
759+
tmem_dealloc_mbar: cutlass.Int64
760760
tmem_holding_buf: cutlass.Int32
761761

762762
smem = utils.SmemAllocator()
@@ -806,11 +806,11 @@ class SharedStorage:
806806
)
807807
# Tensor memory dealloc barrier init
808808
tmem = utils.TmemAllocator(
809-
storage.tmem_holding_buf,
809+
storage.tmem_holding_buf.ptr,
810810
barrier_for_retrieve=tmem_alloc_barrier,
811811
allocator_warp_id=self.epilog_warp_id[0],
812812
is_two_cta=use_2cta_instrs,
813-
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
813+
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
814814
)
815815

816816
# Cluster arrive after barrier init

examples/python/CuTeDSL/distributed/distributed_gemm_all_reduce_blackwell.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ class SharedStorage:
672672
acc_full_mbar_ptr: cute.struct.MemRange[
673673
cutlass.Int64, self.num_acc_stage * 2
674674
]
675-
tmem_dealloc_mbar_ptr: cutlass.Int64
675+
tmem_dealloc_mbar: cutlass.Int64
676676
tmem_holding_buf: cutlass.Int32
677677

678678
smem = utils.SmemAllocator()
@@ -723,11 +723,11 @@ class SharedStorage:
723723
)
724724
# Tensor memory dealloc barrier init
725725
tmem = utils.TmemAllocator(
726-
storage.tmem_holding_buf,
726+
storage.tmem_holding_buf.ptr,
727727
barrier_for_retrieve=tmem_alloc_barrier,
728728
allocator_warp_id=self.epilogue_warp_id[0],
729729
is_two_cta=use_2cta_instrs,
730-
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
730+
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
731731
)
732732

733733
# Cluster arrive after barrier init

0 commit comments

Comments
 (0)