-
Notifications
You must be signed in to change notification settings - Fork 825
[HLK] Modify Wave match test logic to support modifications in different lanes and vector position #7991
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[HLK] Modify Wave match test logic to support modifications in different lanes and vector position #7991
Changes from all commits
e89f1f3
343df5a
d1f0d9e
719ade7
2eee5b6
64779aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1614,41 +1614,88 @@ template <typename T> T waveMultiPrefixProduct(T A, UINT) { | |
|
|
||
| template <typename T> struct Op<OpType::WaveMatch, T, 1> : StrictValidation {}; | ||
|
|
||
| static constexpr UINT ComputeWaveMask(UINT NumWaves) { | ||
| return (NumWaves < 64) ? (1ULL << NumWaves) - 1 : ~0ULL; | ||
| } | ||
|
|
||
| // Helper struct to build the expected result for WaveMatch tests. | ||
| struct WaveMatchResultBuilder { | ||
|
|
||
| private: | ||
| uint64_t LowWaveMask; | ||
| uint64_t HighWaveMask; | ||
| uint64_t ActiveLanesLow; | ||
| uint64_t ActiveLanesHigh; | ||
|
|
||
| public: | ||
| WaveMatchResultBuilder(UINT NumWaves) | ||
| : ActiveLanesLow(0), ActiveLanesHigh(0) { | ||
| VERIFY_IS_TRUE(NumWaves <= 128); | ||
| const UINT LowWaves = std::min(64U, NumWaves); | ||
| const UINT HighWaves = NumWaves - LowWaves; | ||
| LowWaveMask = ComputeWaveMask(LowWaves); | ||
| HighWaveMask = ComputeWaveMask(HighWaves); | ||
| } | ||
|
|
||
| void SetLane(UINT LaneID) { | ||
| if (LaneID < 64) | ||
| ActiveLanesLow |= (1ULL << LaneID) & LowWaveMask; | ||
| else | ||
| ActiveLanesHigh |= (1ULL << (LaneID - 64)) & HighWaveMask; | ||
| } | ||
|
|
||
| void ClearLane(UINT LaneID) { | ||
| if (LaneID < 64) | ||
| ActiveLanesLow &= ~(1ULL << LaneID) & LowWaveMask; | ||
| else | ||
| ActiveLanesHigh &= ~(1ULL << (LaneID - 64)) & HighWaveMask; | ||
| } | ||
|
|
||
| void InvertLanes() { | ||
| ActiveLanesLow = ~ActiveLanesLow & LowWaveMask; | ||
| ActiveLanesHigh = ~ActiveLanesHigh & HighWaveMask; | ||
| } | ||
|
|
||
| void ComputeExpected(UINT *Dest) { | ||
| Dest[0] = static_cast<UINT>(ActiveLanesLow); | ||
| Dest[1] = static_cast<UINT>(ActiveLanesLow >> 32); | ||
| Dest[2] = static_cast<UINT>(ActiveLanesHigh); | ||
| Dest[3] = static_cast<UINT>(ActiveLanesHigh >> 32); | ||
| } | ||
| }; | ||
|
|
||
| template <typename T> struct ExpectedBuilder<OpType::WaveMatch, T> { | ||
| static std::vector<UINT> buildExpected(Op<OpType::WaveMatch, T, 1> &, | ||
| const InputSets<T> &, | ||
| const UINT WaveSize) { | ||
| // For this test, the shader arranges it so that lane 0 is different from | ||
| // all the other lanes. Besides that all other lines write their result of | ||
| // WaveMatch as well. | ||
| // For this test, the shader arranges it so that lanes 0, WAVE_SIZE/2 and | ||
| // WAVE_SIZE-1 are different from all the other lanes, also those | ||
| // lanes modify the vector at positions 0, WAVE_SIZE/2 and WAVE_SIZE-1 | ||
| // respectively, if the input vector has enough elements. Besides that all | ||
| // other lanes write their result of WaveMatch as well. | ||
|
|
||
| std::vector<UINT> Expected; | ||
| Expected.assign(WaveSize * 4, 0); | ||
|
|
||
| const UINT LowWaves = std::min(64U, WaveSize); | ||
| const UINT HighWaves = WaveSize - LowWaves; | ||
|
|
||
| const uint64_t LowWaveMask = | ||
| (LowWaves < 64) ? (1ULL << LowWaves) - 1 : ~0ULL; | ||
| const UINT MidLaneID = WaveSize / 2; | ||
| const UINT LastLaneID = WaveSize - 1; | ||
|
|
||
| const uint64_t HighWaveMask = | ||
| (HighWaves < 64) ? (1ULL << HighWaves) - 1 : ~0ULL; | ||
| WaveMatchResultBuilder UnchangedLanes(WaveSize); | ||
| UnchangedLanes.InvertLanes(); | ||
| UnchangedLanes.ClearLane(0); | ||
| UnchangedLanes.ClearLane(MidLaneID); | ||
| UnchangedLanes.ClearLane(LastLaneID); | ||
|
|
||
| const uint64_t LowExpected = ~1ULL & LowWaveMask; | ||
| const uint64_t HighExpected = ~0ULL & HighWaveMask; | ||
| for (UINT LaneID = 0; LaneID < WaveSize; ++LaneID) { | ||
| const UINT Index = LaneID * 4; | ||
|
|
||
| Expected[0] = 1; | ||
| Expected[1] = 0; | ||
| Expected[2] = 0; | ||
| Expected[3] = 0; | ||
|
|
||
| // all lanes other than the first one have the same result | ||
| for (UINT I = 1; I < WaveSize; ++I) { | ||
| const UINT Index = I * 4; | ||
| Expected[Index] = static_cast<UINT>(LowExpected); | ||
| Expected[Index + 1] = static_cast<UINT>(LowExpected >> 32); | ||
| Expected[Index + 2] = static_cast<UINT>(HighExpected); | ||
| Expected[Index + 3] = static_cast<UINT>(HighExpected >> 32); | ||
| if (LaneID == 0 || LaneID == MidLaneID || LaneID == LastLaneID) { | ||
| WaveMatchResultBuilder ChangedLanes(WaveSize); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What were you intending to do with ChangedLanes? It's declared as a local in this for loop and not used outside of it.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
WaveMatchResultBuilder ChangedLanes(WaveSize);
for (UINT LaneID = 0; LaneID < WaveSize; ++LaneID) {
const UINT Index = LaneID * 4;
if (LaneID == 0 || LaneID == MidLaneID || LaneID == LastLaneID) {
ChangedLanes.SetLane(LaneID);
ChangedLanes.SetExpected(&Expected[Index]);
ChangedLanes.ClearLane(LaneID);
continue;
}
UnchangedLanes.SetExpected(&Expected[Index]);
}Is that better @alsepkow? |
||
| ChangedLanes.SetLane(LaneID); | ||
| ChangedLanes.ComputeExpected(&Expected[Index]); | ||
| continue; | ||
| } | ||
| UnchangedLanes.ComputeExpected(&Expected[Index]); | ||
| } | ||
|
|
||
| return Expected; | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.