// Patch summary: extends the WaveMatch long-vector test. Previously only lane
// 0 perturbed its copy of the input vector (Vector[0]); now lanes 0,
// WAVE_SIZE/2 and WAVE_SIZE-1 each perturb Vector[LaneIndex], guarded by
// LaneIndex < NUM, before every lane calls WaveMatch.
//
// LongVectors.cpp
//   GetWord(b, WordPos): packs bits [WordPos*32, WordPos*32+31] of a 128-bit
//     bitset into one uint32_t; bit I of the result is bitset bit WordPos*32+I.
//   StoreWords(Dest, LanesState): writes the four 32-bit words of the lane
//     mask to Dest[0..3], lowest word first.
//   buildExpected(..., Inputs, WaveSize): asserts exactly one input set and
//     returns WaveSize*4 UINTs, four per lane. Lanes 0, WaveSize/2 and (only
//     when WaveSize-1 < VectorSize) WaveSize-1 expect a mask holding just
//     their own bit; every other lane expects the mask of the remaining lanes.
//
// Changes made during review, relative to the patch as submitted:
//   * UnchangedLanes was built as `std::bitset<128>(~0ULL)` and then ANDed
//     with `(1ULL << WaveSize) - 1`. The unsigned long long constructor only
//     sets bits 0..63, and the shift is undefined for WaveSize >= 64, so wave
//     sizes of 64 and above (which the replaced LowWaves/HighWaves code
//     handled) produced a wrong mask. The mask is now all-ones shifted right
//     by (128 - WaveSize), which is well defined for every WaveSize <= 128.
//   * Dropped the second `Expected.assign(WaveSize * 4, 0)` (it repeated the
//     context line directly above it); hunk new-side count is 57, was 59.
//
// NOTE(review): the host treats lane WaveSize/2 as changed unconditionally,
//   while the shader only modifies when LaneIndex < NUM. If NUM is the vector
//   length, short vectors (VectorSize <= WaveSize/2) will mismatch - confirm.
// NOTE(review): blob hashes on the `index` lines predate these edits.
diff --git a/tools/clang/unittests/HLSLExec/LongVectors.cpp b/tools/clang/unittests/HLSLExec/LongVectors.cpp index 59b4d30359..96ae691fdb 100644 --- a/tools/clang/unittests/HLSLExec/LongVectors.cpp +++ b/tools/clang/unittests/HLSLExec/LongVectors.cpp @@ -1614,41 +1614,57 @@ template T waveMultiPrefixProduct(T A, UINT) { template struct Op : StrictValidation {}; +uint32_t GetWord(const std::bitset<128> &b, uint32_t WordPos) { + uint32_t Word = 0; + for (uint32_t I = 0; I < 32; ++I) + Word |= uint32_t(b[WordPos * 32 + I]) << I; + return Word; +} + +void StoreWords(UINT *Dest, std::bitset<128> LanesState) { + Dest[0] = GetWord(LanesState, 0); + Dest[1] = GetWord(LanesState, 1); + Dest[2] = GetWord(LanesState, 2); + Dest[3] = GetWord(LanesState, 3); +} + template struct ExpectedBuilder { static std::vector buildExpected(Op &, - const InputSets &, + const InputSets &Inputs, const UINT WaveSize) { - // For this test, the shader arranges it so that lane 0 is different from - // all the other lanes. Besides that all other lines write their result of - // WaveMatch as well. + // For this test, the shader arranges it so that lanes 0, WAVE_SIZE/2 and + // WAVE_SIZE-1 are different from all the other lanes, also those + // lanes modify the vector at positions 0, WAVE_SIZE/2 and WAVE_SIZE-1 + // respectively, if the input vector has enough elements. Besides that all + // other lanes write their result of WaveMatch as well. + DXASSERT_NOMSG(Inputs.size() == 1); std::vector Expected; Expected.assign(WaveSize * 4, 0); - const UINT LowWaves = std::min(64U, WaveSize); - const UINT HighWaves = WaveSize - LowWaves; + const size_t VectorSize = Inputs[0].size(); - const uint64_t LowWaveMask = - (LowWaves < 64) ? (1ULL << LowWaves) - 1 : ~0ULL; + const UINT MidLaneID = WaveSize / 2; + const UINT LastLaneID = WaveSize - 1; - const uint64_t HighWaveMask = - (HighWaves < 64) ?
(1ULL << HighWaves) - 1 : ~0ULL; + std::bitset<128> UnchangedLanes; + UnchangedLanes = ~UnchangedLanes >> (UnchangedLanes.size() - WaveSize); + UnchangedLanes.reset(0).reset(MidLaneID); - const uint64_t LowExpected = ~1ULL & LowWaveMask; - const uint64_t HighExpected = ~0ULL & HighWaveMask; + if (LastLaneID < VectorSize) + UnchangedLanes.reset(LastLaneID); - Expected[0] = 1; - Expected[1] = 0; - Expected[2] = 0; - Expected[3] = 0; + for (UINT LaneID = 0; LaneID < WaveSize; ++LaneID) { + const UINT Index = LaneID * 4; - // all lanes other than the first one have the same result - for (UINT I = 1; I < WaveSize; ++I) { - const UINT Index = I * 4; - Expected[Index] = static_cast(LowExpected); - Expected[Index + 1] = static_cast(LowExpected >> 32); - Expected[Index + 2] = static_cast(HighExpected); - Expected[Index + 3] = static_cast(HighExpected >> 32); + if (LaneID == 0 || LaneID == MidLaneID || + (LastLaneID < VectorSize && LaneID == LastLaneID)) { + std::bitset<128> ChangedLanes(0); + ChangedLanes = ChangedLanes.set(LaneID); + StoreWords(&Expected[Index], ChangedLanes); + continue; + } + StoreWords(&Expected[Index], UnchangedLanes); } return Expected; diff --git a/tools/clang/unittests/HLSLExec/ShaderOpArith.xml b/tools/clang/unittests/HLSLExec/ShaderOpArith.xml index 2cfeb1f225..d7c48749a6 100644 --- a/tools/clang/unittests/HLSLExec/ShaderOpArith.xml +++ b/tools/clang/unittests/HLSLExec/ShaderOpArith.xml @@ -4408,15 +4408,24 @@ void MSMain(uint GID : SV_GroupIndex, #ifdef FUNC_WAVE_MATCH void TestWaveMatch(vector Vector) { - if(WaveGetLaneIndex() == 0) + uint LaneIndex = WaveGetLaneIndex(); + bool ShouldModify = ( LaneIndex == 0 || + LaneIndex == (WAVE_SIZE / 2) || + LaneIndex == (WAVE_SIZE - 1)); + + if(LaneIndex < NUM && ShouldModify) { - if(Vector[0] == (TYPE)0) - Vector[0] = (TYPE) 1; - else if(Vector[0] == (TYPE)1) - Vector[0] = (TYPE) 0; + if(Vector[LaneIndex] == (TYPE) 0) + Vector[LaneIndex] = (TYPE) 1; + else if(Vector[LaneIndex] == (TYPE) 1) + Vector[LaneIndex] = (TYPE) 0; else
- Vector[0] = (TYPE) 1; + Vector[LaneIndex] = (TYPE) 1; } + + // Making sure all lanes finish updating their vectors. + AllMemoryBarrierWithGroupSync(); + uint4 result = WaveMatch(Vector); uint index = WaveGetLaneIndex();