[CK_TILE] Stream-K XCD remapping #3652
base: develop
This change adds a function that remaps block IDs from their original round-robin assignment to a contiguous layout across XCDs. The function is added to the StreamKTilePartitioner and called in the operator() functions. Unit tests verify the function's correctness on minimal arrays. These changes should improve locality and the cache hit rate, and therefore overall performance.
 * @param NUM_XCDS number of XCDs
 * @return index_t The id after XCD remap
 */
CK_TILE_HOST_DEVICE index_t RemapXCD(index_t block_1d_id,
For the sake of consistency, can we use lower snake case for this function name?
I think there is a ticket to update all casing in this class, but for now, I think it might be best to keep the style the same within the file.
 *
 * @param block_1d_id grid 1D id
 * @param total_num_tiles size of the 1D grid
 * @param NUM_XCDS number of XCDs
Can we make this param lower snake case?
index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas();

block_idx = kargs.tile_partitioner.RemapXCD(block_idx, grid_size);
Since this change should only apply to gfx942 and gfx950, we likely need some kind of logic to determine whether we want to apply the remap or not. Also, depending on the variant of gfx942 there may be fewer than 8 XCDs, so I think some additional logic will be required to determine what value for num_xcds we want to use (rather than solely relying on the default of 8).
We want to avoid using preprocessor macros, so depending on what is available in the HIP API, maybe we can query the number of XCDs from the device?
If not, we may need to consider other options. Perhaps we could consider some logic that follows a similar pattern to: include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp.
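One way the gating described above could look is sketched here, host side only. Everything in it is an assumption for illustration: `XcdRemapPolicy` and `select_xcd_remap_policy` are hypothetical names, the architecture string is assumed to have the form reported by `hipDeviceProp_t::gcnArchName` (for example `gfx942:sramecc+:xnack-`), and the XCD count is taken as an input because how to obtain it is still the open question in this comment.

```cpp
#include <string>

// Hypothetical result type: whether to remap, and with how many XCDs.
struct XcdRemapPolicy
{
    bool enabled;
    int num_xcds;
};

// Hypothetical helper: decide at run time whether the XCD remap applies.
// arch_name        - architecture string, e.g. "gfx942:sramecc+:xnack-" (assumed format)
// queried_num_xcds - XCD count obtained elsewhere; its source is left open here
inline XcdRemapPolicy select_xcd_remap_policy(const std::string& arch_name,
                                              int queried_num_xcds)
{
    // Prefix match so that feature suffixes after the colon are ignored.
    const bool multi_xcd_arch =
        arch_name.rfind("gfx942", 0) == 0 || arch_name.rfind("gfx950", 0) == 0;

    // With one XCD (or an unknown count) the remap has nothing to improve.
    if(!multi_xcd_arch || queried_num_xcds <= 1)
    {
        return {false, 1};
    }
    return {true, queried_num_xcds};
}
```

The kernel arguments could then carry `num_xcds` and skip the remap when the policy is disabled, which avoids preprocessor macros but is only one of several possible designs.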
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
    tile_idx += kargs.tile_partitioner.get_grid())
    block_idx =
        kargs.tile_partitioner.RemapXCD(block_idx, grid_size)
Same comment about how we shouldn't do this on all architectures. (See above)
    EXPECT_EQ(tile_local_cta_idx, expected_tile_local_cta_idx);
}

template <typename GemmShape>
Can we also add a test for when we don't use the default number of XCDs?
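A rough idea of what such a test could check, written without the test framework so it stands alone. The remap below is a hypothetical stand-in (`remap_for_check`), not the function in this PR, and `remap_is_valid` is likewise an illustrative name. The two properties checked are that the remap is a permutation of the grid and that blocks landing on the same XCD receive consecutive ids.

```cpp
#include <vector>

using index_t = int;

// Hypothetical stand-in for the remap under test (assumes num_xcds > 0).
inline index_t remap_for_check(index_t block_1d_id, index_t total_num_tiles, index_t num_xcds)
{
    const index_t xcd   = block_1d_id % num_xcds;
    const index_t local = block_1d_id / num_xcds;
    const index_t base  = total_num_tiles / num_xcds;
    const index_t rem   = total_num_tiles % num_xcds;
    return xcd * base + (xcd < rem ? xcd : rem) + local;
}

// Returns true when the remap is a bijection on [0, grid) and when two
// blocks scheduled back to back on the same XCD map to adjacent ids.
inline bool remap_is_valid(index_t grid, index_t num_xcds)
{
    std::vector<bool> seen(static_cast<std::size_t>(grid), false);
    for(index_t id = 0; id < grid; ++id)
    {
        const index_t r = remap_for_check(id, grid, num_xcds);
        if(r < 0 || r >= grid || seen[static_cast<std::size_t>(r)])
        {
            return false; // out of range or duplicate target
        }
        seen[static_cast<std::size_t>(r)] = true;

        // The next block on the same XCD under round-robin is id + num_xcds.
        if(id + num_xcds < grid &&
           remap_for_check(id + num_xcds, grid, num_xcds) != r + 1)
        {
            return false;
        }
    }
    return true;
}
```

In the existing test file this would presumably become something like `EXPECT_TRUE(remap_is_valid(grid, 4))`, run for a few XCD counts other than 8 and for grid sizes that are not multiples of the XCD count.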
Proposed changes
This PR adds support for XCD remapping as detailed in this document. On gfx942, workgroups are typically scheduled round-robin across XCDs, which can lead to poor locality. We use a remapping that assigns workgroups to contiguous tiles within each XCD, improving locality and the cache hit rate. This is done through a function from this PR that computes the contiguous mapping, which we have added to the StreamKTilePartitioner. The approach requires minimal changes to the Stream-K algorithm: only a remap at the time the workgroups are partitioned. Better data locality means more cache hits, which helps close the performance gaps seen with the default scheduling. Unit tests have been added to verify the function in isolation. The optimization is not specific to Stream-K GEMM and can be applied to other GEMM kernels.
Note: This only applies to gfx942, which introduces the XCDs.
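To make the round-robin to contiguous mapping concrete, here is a minimal sketch of one way such a remap can be computed. It is not the code in this PR: `remap_xcd_sketch` is a hypothetical name, the default of 8 XCDs mirrors the default mentioned in review, and the handling of grids that do not divide evenly (the first `rem` XCDs take one extra block) is an assumption.

```cpp
#include <cassert>
#include <cstdint>

using index_t = std::int32_t;

// Hypothetical sketch: map a round-robin block id to a contiguous id so that
// all blocks scheduled on the same XCD work on neighbouring tiles.
inline index_t remap_xcd_sketch(index_t block_1d_id,
                                index_t total_num_tiles,
                                index_t num_xcds = 8)
{
    assert(num_xcds > 0);
    assert(block_1d_id >= 0 && block_1d_id < total_num_tiles);

    const index_t xcd   = block_1d_id % num_xcds; // XCD chosen by round-robin
    const index_t local = block_1d_id / num_xcds; // position within that XCD

    const index_t base = total_num_tiles / num_xcds; // blocks every XCD gets
    const index_t rem  = total_num_tiles % num_xcds; // first rem XCDs get one more

    // Start of this XCD's contiguous range: blocks owned by all earlier XCDs.
    const index_t start = xcd * base + (xcd < rem ? xcd : rem);
    return start + local;
}
```

With 8 blocks and 4 XCDs, blocks 0 and 4 both run on XCD 0 under round-robin scheduling; the sketch maps them to ids 0 and 1, while blocks 1 and 5 (XCD 1) map to ids 2 and 3.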
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
clang-format on all changed files