diff --git a/zstd/zstdgpu/Shaders/ZstdGpuComputePrefixSum.hlsl b/zstd/zstdgpu/Shaders/ZstdGpuComputePrefixSum.hlsl index 044a1cd..1d0f7ae 100644 --- a/zstd/zstdgpu/Shaders/ZstdGpuComputePrefixSum.hlsl +++ b/zstd/zstdgpu/Shaders/ZstdGpuComputePrefixSum.hlsl @@ -22,6 +22,7 @@ struct Consts { uint32_t elemToPrefixCount; + uint32_t literalsPerGroup; }; ConstantBuffer Constants : register(b0); @@ -37,7 +38,7 @@ RWStructuredBuffer ZstdLitGroupCountToPrefixLookback : register(u RWStructuredBuffer ZstdCounters : register(u4); -[RootSignature("UAV(u0), UAV(u1), UAV(u2), UAV(u3), UAV(u4), RootConstants(b0, num32BitConstants=1)")] +[RootSignature("UAV(u0), UAV(u1), UAV(u2), UAV(u3), UAV(u4), RootConstants(b0, num32BitConstants=2)")] [numthreads(kzstdgpu_TgSizeX_PrefixSum_LiteralCount, 1, 1)] void main(uint i : SV_DispatchThreadId) { @@ -51,7 +52,7 @@ void main(uint i : SV_DispatchThreadId) const uint32_t lastLocalIndex = WaveActiveCountBits(true) - 1u; const uint32_t streamCount = ZstdLitStreamCountToPrefix[i]; - const uint32_t groupCount = ZSTDGPU_TG_COUNT(streamCount, kzstdgpu_TgSizeX_DecompressLiterals); + const uint32_t groupCount = ZSTDGPU_TG_COUNT(streamCount, Constants.literalsPerGroup); const uint32_t streamPrefix = WavePrefixSum(streamCount); const uint32_t groupPrefix = WavePrefixSum(groupCount); diff --git a/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache.hlsli b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache.hlsli new file mode 100644 index 0000000..a13e740 --- /dev/null +++ b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache.hlsli @@ -0,0 +1,159 @@ +/** + * ZstdGpuDecompressLiterals_LdsStoreCache.hlsli + * + * A compute shader that decompresses Huffman-compressed literals using + * an LDS store cache for cooperative dword-aligned writes. + * + * The Huffman table is packed: two symbol+bitcnt pairs per dword (each pair + * is 16 bits: 8-bit symbol | 8-bit bitcnt). 
Decoded literals are first + * accumulated into dwords, staged in an LDS cache, and then cooperatively + * flushed to device memory via a dword-typed UAV. + * + * The following must be defined before including this file: + * 'kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache' + * -- threadgroup size, also used as + * the number of dwords cached in LDS + * per decoded literal stream. + * 'kzstdgpu_DecompressLiterals_StreamsPerGroup' -- number of literal streams processed + * per threadgroup. + * + * Copyright (c) Microsoft. All rights reserved. + * This code is licensed under the MIT License (MIT). + * THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF + * ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY + * IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR + * PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT. + * + * Advanced Technology Group (ATG) + * Author(s): Pavel Martishevsky (pamartis@microsoft.com) + */ + +#ifdef __XBOX_SCARLETT +#define __XBOX_ENABLE_WAVE32 1 +#endif + +#ifndef kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache +# error 'kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache' must be defined before including this '.hlsli' +#endif + +#ifndef kzstdgpu_DecompressLiterals_StreamsPerGroup +# error 'kzstdgpu_DecompressLiterals_StreamsPerGroup' must be defined before including this '.hlsli' +#endif + +#include "../zstdgpu_shaders.h" + +// LDS layout for the LdsStoreCache variant: Huffman table + per-stream store cache. +// Cache size per stream equals the threadgroup size (kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache). 
+#define ZSTDGPU_DECOMPRESS_LITERALS_LDS_STORE_CACHE(base, size) \ + ZSTDGPU_LDS_SIZE(size) \ + ZSTDGPU_LDS_BASE(base) \ + ZSTDGPU_LDS_REGION(HuffmanTable, 1u << (kzstdgpu_MaxCount_HuffmanWeightBits - 1)) \ + ZSTDGPU_LDS_REGION(LiteralStoreCache, kzstdgpu_DecompressLiterals_StreamsPerGroup * kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache) + +#include "../zstdgpu_lds_decl_size.h" +ZSTDGPU_DECOMPRESS_LITERALS_LDS_STORE_CACHE(0, DecompressLiterals_LdsStoreCache); +#include "../zstdgpu_lds_decl_undef.h" + +struct Consts +{ + uint32_t huffmanTableSlotCount; +}; + +ConstantBuffer Constants : register(b0); + +#include "../zstdgpu_srt_decl_bind.h" +ZSTDGPU_DECOMPRESS_LITERALS_SRT() +ZSTDGPU_RW_BUFFER_DECL(uint32_t, DecompressedLiteralsAsDwords, 1) +#include "../zstdgpu_srt_decl_undef.h" + +groupshared uint32_t GS_Lds[kzstdgpu_DecompressLiterals_LdsStoreCache_LdsSize]; +#define ZSTDGPU_LDS GS_Lds +#include "../zstdgpu_lds_hlsl.h" + +[RootSignature("DescriptorTable(SRV(t0, numDescriptors=9), UAV(u0, numDescriptors=1)),UAV(u1), RootConstants(b0, num32BitConstants=1)")] +[numthreads(kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache, 1, 1)] +void main(uint groupId : SV_GroupId, uint i : SV_GroupThreadId) +{ + zstdgpu_DecompressLiterals_SRT srt; + + #include "../zstdgpu_srt_decl_copy.h" + ZSTDGPU_DECOMPRESS_LITERALS_SRT() + #include "../zstdgpu_srt_decl_undef.h" + srt.huffmanTableSlotCount = Constants.huffmanTableSlotCount; + + if (groupId >= srt.inCounters[kzstdgpu_CounterIndex_DecompressLiteralsGroups]) + return; + + uint32_t htIndex = 0; + uint32_t htGroupStart = 0; + uint32_t htLiteralStart = 0; + uint32_t htLiteralCount = 0; + + zstdgpu_ConvertThreadgroupIdToDecompressLiteralsInputs( + srt.inLitGroupEndPerHuffmanTable, + srt.inLitStreamEndPerHuffmanTable, + srt.huffmanTableSlotCount, + groupId, + htIndex, + htGroupStart, + htLiteralStart, + htLiteralCount + ); + + #include "../zstdgpu_lds_decl_base.h" + ZSTDGPU_DECOMPRESS_LITERALS_LDS_STORE_CACHE(0, 
DecompressLiterals_LdsStoreCache); + #include "../zstdgpu_lds_decl_undef.h" + + const uint32_t htInfo = WaveReadLaneFirst(srt.inHuffmanTableInfo[htIndex]); + const uint32_t bitsMax = htInfo >> 16; + const uint32_t codeTableSize = htInfo & 0xffffu; + const uint32_t stateCnt = WaveReadLaneFirst(srt.inHuffmanTableRankIndex[htIndex * kzstdgpu_MaxCount_HuffmanWeightRanks + bitsMax]); + const uint32_t statePairCnt = stateCnt >> 1u; + + // Expand Huffman Table — pack two symbol+bitcnt pairs per dword + ZSTDGPU_FOR_WORK_ITEMS(statePairId, statePairCnt, i, kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache) + { + const uint32_t stateId0 = statePairId << 1u; + const uint32_t stateId1 = stateId0 + 1u; + + const uint32_t symbolIndex0 = zstdgpu_BinarySearchMasked(srt.inHuffmanTableCodeAndSymbol, htIndex * kzstdgpu_MaxCount_HuffmanWeights, codeTableSize, stateId0, 0x00ffffffu); + const uint32_t symbolIndex1 = zstdgpu_BinarySearchMasked(srt.inHuffmanTableCodeAndSymbol, htIndex * kzstdgpu_MaxCount_HuffmanWeights, codeTableSize, stateId1, 0x00ffffffu); + + const uint32_t bitcntIndex0 = zstdgpu_BinarySearchMasked(srt.inHuffmanTableRankIndex, htIndex * kzstdgpu_MaxCount_HuffmanWeightRanks, bitsMax + 1, stateId0, 0xffffffffu) + - htIndex * kzstdgpu_MaxCount_HuffmanWeightRanks; + + const uint32_t bitcntIndex1 = zstdgpu_BinarySearchMasked(srt.inHuffmanTableRankIndex, htIndex * kzstdgpu_MaxCount_HuffmanWeightRanks, bitsMax + 1, stateId1, 0xffffffffu) + - htIndex * kzstdgpu_MaxCount_HuffmanWeightRanks; + + const uint32_t symbol0 = srt.inHuffmanTableCodeAndSymbol[symbolIndex0] >> 24; + const uint32_t bitcnt0 = bitsMax - bitcntIndex0; + + const uint32_t symbol1 = srt.inHuffmanTableCodeAndSymbol[symbolIndex1] >> 24; + const uint32_t bitcnt1 = bitsMax - bitcntIndex1; + + const uint32_t symbolAndBitcnt0 = (symbol0 << 8) | bitcnt0; + const uint32_t symbolAndBitcnt1 = (symbol1 << 8) | bitcnt1; + + zstdgpu_LdsStoreU32(GS_HuffmanTable + statePairId, (symbolAndBitcnt1 << 16) | symbolAndBitcnt0); + } 
+ GroupMemoryBarrierWithGroupSync(); + + zstdgpu_DecompressHuffmanCompressedLiterals_StoreLdsCache( + srt.inCompressedData, + srt.inLitStreamRemap, + srt.inLitRefs, + srt.inoutDecompressedLiterals, + ZstdInOutDecompressedLiteralsAsDwords, + GS_HuffmanTable, + GS_LiteralStoreCache, + groupId, + i, + htGroupStart, + htLiteralStart, + htLiteralCount, + bitsMax, + kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache, + kzstdgpu_DecompressLiterals_StreamsPerGroup, + kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache + ); +} diff --git a/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache128_8.hlsl b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache128_8.hlsl new file mode 100644 index 0000000..e298e16 --- /dev/null +++ b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache128_8.hlsl @@ -0,0 +1,19 @@ +/** + * ZstdGpuDecompressLiterals_LdsStoreCache128_8.hlsl + * + * LDS store cache variant: 128 threads, 8 streams per group, 128 dwords cached per stream. + * + * Copyright (c) Microsoft. All rights reserved. + * This code is licensed under the MIT License (MIT). + * THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF + * ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY + * IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR + * PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT. 
+ * + * Advanced Technology Group (ATG) + * Author(s): Pavel Martishevsky (pamartis@microsoft.com) + */ + +#define kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache 128 +#define kzstdgpu_DecompressLiterals_StreamsPerGroup 8 +#include "ZstdGpuDecompressLiterals_LdsStoreCache.hlsli" diff --git a/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache32_16.hlsl b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache32_16.hlsl new file mode 100644 index 0000000..50d38c4 --- /dev/null +++ b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache32_16.hlsl @@ -0,0 +1,19 @@ +/** + * ZstdGpuDecompressLiterals_LdsStoreCache32_16.hlsl + * + * LDS store cache variant: 32 threads, 16 streams per group, 32 dwords cached per stream. + * + * Copyright (c) Microsoft. All rights reserved. + * This code is licensed under the MIT License (MIT). + * THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF + * ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY + * IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR + * PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT. + * + * Advanced Technology Group (ATG) + * Author(s): Pavel Martishevsky (pamartis@microsoft.com) + */ + +#define kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache 32 +#define kzstdgpu_DecompressLiterals_StreamsPerGroup 16 +#include "ZstdGpuDecompressLiterals_LdsStoreCache.hlsli" diff --git a/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache32_32.hlsl b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache32_32.hlsl new file mode 100644 index 0000000..49f75cc --- /dev/null +++ b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache32_32.hlsl @@ -0,0 +1,19 @@ +/** + * ZstdGpuDecompressLiterals_LdsStoreCache32_32.hlsl + * + * LDS store cache variant: 32 threads, 32 streams per group, 32 dwords cached per stream. + * + * Copyright (c) Microsoft. All rights reserved. + * This code is licensed under the MIT License (MIT). 
+ * THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF + * ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY + * IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR + * PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT. + * + * Advanced Technology Group (ATG) + * Author(s): Pavel Martishevsky (pamartis@microsoft.com) + */ + +#define kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache 32 +#define kzstdgpu_DecompressLiterals_StreamsPerGroup 32 +#include "ZstdGpuDecompressLiterals_LdsStoreCache.hlsli" diff --git a/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache32_8.hlsl b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache32_8.hlsl new file mode 100644 index 0000000..e05988b --- /dev/null +++ b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache32_8.hlsl @@ -0,0 +1,19 @@ +/** + * ZstdGpuDecompressLiterals_LdsStoreCache32_8.hlsl + * + * LDS store cache variant: 32 threads, 8 streams per group, 32 dwords cached per stream. + * + * Copyright (c) Microsoft. All rights reserved. + * This code is licensed under the MIT License (MIT). + * THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF + * ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY + * IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR + * PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT. 
+ * + * Advanced Technology Group (ATG) + * Author(s): Pavel Martishevsky (pamartis@microsoft.com) + */ + +#define kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache 32 +#define kzstdgpu_DecompressLiterals_StreamsPerGroup 8 +#include "ZstdGpuDecompressLiterals_LdsStoreCache.hlsli" diff --git a/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache64_16.hlsl b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache64_16.hlsl new file mode 100644 index 0000000..5b28a8d --- /dev/null +++ b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache64_16.hlsl @@ -0,0 +1,19 @@ +/** + * ZstdGpuDecompressLiterals_LdsStoreCache64_16.hlsl + * + * LDS store cache variant: 64 threads, 16 streams per group, 64 dwords cached per stream. + * + * Copyright (c) Microsoft. All rights reserved. + * This code is licensed under the MIT License (MIT). + * THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF + * ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY + * IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR + * PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT. + * + * Advanced Technology Group (ATG) + * Author(s): Pavel Martishevsky (pamartis@microsoft.com) + */ + +#define kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache 64 +#define kzstdgpu_DecompressLiterals_StreamsPerGroup 16 +#include "ZstdGpuDecompressLiterals_LdsStoreCache.hlsli" diff --git a/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache64_8.hlsl b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache64_8.hlsl new file mode 100644 index 0000000..ded408c --- /dev/null +++ b/zstd/zstdgpu/Shaders/ZstdGpuDecompressLiterals_LdsStoreCache64_8.hlsl @@ -0,0 +1,19 @@ +/** + * ZstdGpuDecompressLiterals_LdsStoreCache64_8.hlsl + * + * LDS store cache variant: 64 threads, 8 streams per group, 64 dwords cached per stream. + * + * Copyright (c) Microsoft. All rights reserved. + * This code is licensed under the MIT License (MIT). 
+ * THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF + * ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY + * IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR + * PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT. + * + * Advanced Technology Group (ATG) + * Author(s): Pavel Martishevsky (pamartis@microsoft.com) + */ + +#define kzstdgpu_TgSizeX_DecompressLiterals_LdsStoreCache 64 +#define kzstdgpu_DecompressLiterals_StreamsPerGroup 8 +#include "ZstdGpuDecompressLiterals_LdsStoreCache.hlsli" diff --git a/zstd/zstdgpu/zstdgpu.cpp b/zstd/zstdgpu/zstdgpu.cpp index 29a976f..a89e305 100644 --- a/zstd/zstdgpu/zstdgpu.cpp +++ b/zstd/zstdgpu/zstdgpu.cpp @@ -39,6 +39,12 @@ #include "ZstdGpuDecodeHuffmanWeights.h" #include "ZstdGpuDecompressHuffmanWeights.h" #include "ZstdGpuDecompressLiterals.h" +#include "ZstdGpuDecompressLiterals_LdsStoreCache128_8.h" +#include "ZstdGpuDecompressLiterals_LdsStoreCache64_16.h" +#include "ZstdGpuDecompressLiterals_LdsStoreCache64_8.h" +#include "ZstdGpuDecompressLiterals_LdsStoreCache32_32.h" +#include "ZstdGpuDecompressLiterals_LdsStoreCache32_16.h" +#include "ZstdGpuDecompressLiterals_LdsStoreCache32_8.h" #include "ZstdGpuDecompressSequences.h" #include "ZstdGpuDecompressSequences_LdsFseCache128.h" #include "ZstdGpuDecompressSequences_LdsFseCache64.h" @@ -302,31 +308,37 @@ static void zstdgpu_ReCreate_SRTs(zstdgpu_SRTs & srts, ID3D12Device *device, con #undef ZSTDGPU_RO_RAW_BUFFER_DECL } -#define ZSTDGPU_KERNEL_LIST() \ - ZSTDGPU_KERNEL(ComputeDestSequenceOffsets , L"Compute Destination Sequence Offsets") \ - ZSTDGPU_KERNEL(ComputePrefixSum , L"Compute Prefix of Literal and TG Count for Literal Decompression") \ - ZSTDGPU_KERNEL(DecodeHuffmanWeights , L"Decode (from nibbles) Uncompressed Huffman Weights") \ - ZSTDGPU_KERNEL(DecompressHuffmanWeights , L"Decompress FSE-compressed Huffman Weights") \ - ZSTDGPU_KERNEL(DecompressLiterals , L"Decompress Literals") \ - ZSTDGPU_KERNEL(DecompressSequences , L"Decompress Sequences") \ - 
ZSTDGPU_KERNEL(DecompressSequences_LdsFseCache128 , L"Decompress Sequences (LDS FSE Cache 128)") \ - ZSTDGPU_KERNEL(DecompressSequences_LdsFseCache64 , L"Decompress Sequences (LDS FSE Cache 64)") \ - ZSTDGPU_KERNEL(DecompressSequences_LdsFseCache32 , L"Decompress Sequences (LDS FSE Cache 32)") \ - ZSTDGPU_KERNEL(ExecuteSequences128 , L"Execute Sequences 128") \ - ZSTDGPU_KERNEL(ExecuteSequences64 , L"Execute Sequences 64") \ - ZSTDGPU_KERNEL(ExecuteSequences32 , L"Execute Sequences 32") \ - ZSTDGPU_KERNEL(FinaliseSequenceOffsets , L"Finalise Sequence Offsets") \ - ZSTDGPU_KERNEL(GroupCompressedLiterals , L"Group Huffman-compressed Literals") \ - ZSTDGPU_KERNEL(InitFseTable , L"Init Fse Table") \ - ZSTDGPU_KERNEL(InitHuffmanTable , L"Init Huffman Table") \ - ZSTDGPU_KERNEL(InitHuffmanTableAndDecompressLiterals , L"Init Huffman Table and Decompress Literals") \ - ZSTDGPU_KERNEL(InitResources , L"Init Resources") \ - ZSTDGPU_KERNEL(MemsetMemcpy , L"Memset-Memcpy") \ - ZSTDGPU_KERNEL(ParseCompressedBlocks , L"Parse Compressed Blocks") \ - ZSTDGPU_KERNEL(ParseFrames , L"Parse Frames") \ - ZSTDGPU_KERNEL(PrefixSequenceOffsets , L"Prefix Sequence Offsets") \ - ZSTDGPU_KERNEL(PrefixSum , L"Prefix Sum") \ - ZSTDGPU_KERNEL(UpdateDispatchArgs , L"Update Dispatch Args") +#define ZSTDGPU_KERNEL_LIST() \ + ZSTDGPU_KERNEL(ComputeDestSequenceOffsets , L"Compute Destination Sequence Offsets") \ + ZSTDGPU_KERNEL(ComputePrefixSum , L"Compute Prefix of Literal and TG Count for Literal Decompression") \ + ZSTDGPU_KERNEL(DecodeHuffmanWeights , L"Decode (from nibbles) Uncompressed Huffman Weights") \ + ZSTDGPU_KERNEL(DecompressHuffmanWeights , L"Decompress FSE-compressed Huffman Weights") \ + ZSTDGPU_KERNEL(DecompressLiterals , L"Decompress Literals") \ + ZSTDGPU_KERNEL(DecompressLiterals_LdsStoreCache128_8 , L"Decompress Literals (LDS Store Cache=128 Dwords, Stream Count= 8)") \ + ZSTDGPU_KERNEL(DecompressLiterals_LdsStoreCache64_16 , L"Decompress Literals (LDS Store Cache= 64 Dwords, 
Stream Count=16)") \ + ZSTDGPU_KERNEL(DecompressLiterals_LdsStoreCache64_8 , L"Decompress Literals (LDS Store Cache= 64 Dwords, Stream Count= 8)") \ + ZSTDGPU_KERNEL(DecompressLiterals_LdsStoreCache32_32 , L"Decompress Literals (LDS Store Cache= 32 Dwords, Stream Count=32)") \ + ZSTDGPU_KERNEL(DecompressLiterals_LdsStoreCache32_16 , L"Decompress Literals (LDS Store Cache= 32 Dwords, Stream Count=16)") \ + ZSTDGPU_KERNEL(DecompressLiterals_LdsStoreCache32_8 , L"Decompress Literals (LDS Store Cache= 32 Dwords, Stream Count= 8)") \ + ZSTDGPU_KERNEL(DecompressSequences , L"Decompress Sequences") \ + ZSTDGPU_KERNEL(DecompressSequences_LdsFseCache128 , L"Decompress Sequences (LDS FSE Cache, TG Size= 128)") \ + ZSTDGPU_KERNEL(DecompressSequences_LdsFseCache64 , L"Decompress Sequences (LDS FSE Cache, TG Size= 64)") \ + ZSTDGPU_KERNEL(DecompressSequences_LdsFseCache32 , L"Decompress Sequences (LDS FSE Cache, TG Size= 32)") \ + ZSTDGPU_KERNEL(ExecuteSequences128 , L"Execute Sequences 128") \ + ZSTDGPU_KERNEL(ExecuteSequences64 , L"Execute Sequences 64") \ + ZSTDGPU_KERNEL(ExecuteSequences32 , L"Execute Sequences 32") \ + ZSTDGPU_KERNEL(FinaliseSequenceOffsets , L"Finalise Sequence Offsets") \ + ZSTDGPU_KERNEL(GroupCompressedLiterals , L"Group Huffman-compressed Literals") \ + ZSTDGPU_KERNEL(InitFseTable , L"Init Fse Table") \ + ZSTDGPU_KERNEL(InitHuffmanTable , L"Init Huffman Table") \ + ZSTDGPU_KERNEL(InitHuffmanTableAndDecompressLiterals , L"Init Huffman Table and Decompress Literals") \ + ZSTDGPU_KERNEL(InitResources , L"Init Resources") \ + ZSTDGPU_KERNEL(MemsetMemcpy , L"Memset-Memcpy") \ + ZSTDGPU_KERNEL(ParseCompressedBlocks , L"Parse Compressed Blocks") \ + ZSTDGPU_KERNEL(ParseFrames , L"Parse Frames") \ + ZSTDGPU_KERNEL(PrefixSequenceOffsets , L"Prefix Sequence Offsets") \ + ZSTDGPU_KERNEL(PrefixSum , L"Prefix Sum") \ + ZSTDGPU_KERNEL(UpdateDispatchArgs , L"Update Dispatch Args") #define ZSTDGPU_KERNEL_SCOPE_LIST_STAGE_0() \ 
ZSTDGPU_KERNEL_SCOPE_X(InitResources_CountBlocks , L"Init Resources" ) \ @@ -398,6 +410,8 @@ struct zstdgpu_PerRequestContextImpl #undef ZSTDGPU_KERNEL d3d12aid_ComputeRsPs ExecuteSequences; d3d12aid_ComputeRsPs DecompressSequences_LdsFseCache; + d3d12aid_ComputeRsPs DecompressLiterals_LdsStoreCache; + uint32_t DecompressLiterals_LdsStoreCache_StreamsPerGroup; zstdgpu_SRTs srts; zstdgpu_ResourceDataGpu resData; @@ -559,21 +573,29 @@ zstdgpu_Status zstdgpu_CreatePerRequestContext(zstdgpu_PerRequestContext *outPer #ifdef _GAMING_XBOX_SCARLETT context->ExecuteSequences = context->ExecuteSequences64; context->DecompressSequences_LdsFseCache = context->DecompressSequences_LdsFseCache32; + context->DecompressLiterals_LdsStoreCache = context->DecompressLiterals_LdsStoreCache32_16; + context->DecompressLiterals_LdsStoreCache_StreamsPerGroup = 16; #else if (persistentContext->maxLaneCount == 128) { context->ExecuteSequences = context->ExecuteSequences128; context->DecompressSequences_LdsFseCache = context->DecompressSequences_LdsFseCache128; + context->DecompressLiterals_LdsStoreCache = context->DecompressLiterals_LdsStoreCache128_8; + context->DecompressLiterals_LdsStoreCache_StreamsPerGroup = 8; } else if (persistentContext->maxLaneCount == 64) { context->ExecuteSequences = context->ExecuteSequences64; context->DecompressSequences_LdsFseCache = context->DecompressSequences_LdsFseCache64; + context->DecompressLiterals_LdsStoreCache = context->DecompressLiterals_LdsStoreCache64_16; + context->DecompressLiterals_LdsStoreCache_StreamsPerGroup = 16; } else { context->ExecuteSequences = context->ExecuteSequences32; context->DecompressSequences_LdsFseCache = context->DecompressSequences_LdsFseCache32; + context->DecompressLiterals_LdsStoreCache = context->DecompressLiterals_LdsStoreCache32_16; + context->DecompressLiterals_LdsStoreCache_StreamsPerGroup = 16; } #endif context->ExecuteSequences.rs->AddRef(); @@ -582,6 +604,9 @@ zstdgpu_Status 
zstdgpu_CreatePerRequestContext(zstdgpu_PerRequestContext *outPer context->DecompressSequences_LdsFseCache.rs->AddRef(); context->DecompressSequences_LdsFseCache.ps->AddRef(); + context->DecompressLiterals_LdsStoreCache.rs->AddRef(); + context->DecompressLiterals_LdsStoreCache.ps->AddRef(); + context->srts.heap = NULL; context->srts.heapOffset = 0; zstdgpu_ResourceDataGpu_InitZero(&context->resData); @@ -646,6 +671,7 @@ zstdgpu_Status zstdgpu_DestroyPerRequestContext(void **outMemoryBlock, uint32_t d3d12aid_ComputeRsPs_Release(&inPerRequestContext->ExecuteSequences); d3d12aid_ComputeRsPs_Release(&inPerRequestContext->DecompressSequences_LdsFseCache); + d3d12aid_ComputeRsPs_Release(&inPerRequestContext->DecompressLiterals_LdsStoreCache); #define ZSTDGPU_KERNEL(name, desc) d3d12aid_ComputeRsPs_Release(&inPerRequestContext->name); ZSTDGPU_KERNEL_LIST() #undef ZSTDGPU_KERNEL @@ -1478,6 +1504,13 @@ void zstdgpu_SubmitStage2(zstdgpu_PerRequestContext req, ID3D12GraphicsCommandLi cmdList->SetComputeRootUnorderedAccessView(3, req->resData.gpuOnly.LitGroupEndPerHuffmanTable->GetGPUVirtualAddress() + req->zstdCmpBlockCount * sizeof(uint32_t)); cmdList->SetComputeRootUnorderedAccessView(4, req->resData.gpuOnly.Counters->GetGPUVirtualAddress()); cmdList->SetComputeRoot32BitConstant(5, req->zstdCmpBlockCount, 0); +#if 0 + // NOTE(pamartis): Use this path with the DecompressLiterals kernel + cmdList->SetComputeRoot32BitConstant(5, kzstdgpu_TgSizeX_DecompressLiterals, 1); +#else + // NOTE(pamartis): Use this path with the DecompressLiterals_LdsStoreCache* kernels + cmdList->SetComputeRoot32BitConstant(5, req->DecompressLiterals_LdsStoreCache_StreamsPerGroup, 1); +#endif ZSTDGPU_KERNEL_SCOPE(ComputePrefixSum, cmdList, cmdList->Dispatch(ZSTDGPU_TG_COUNT(req->zstdCmpBlockCount, kzstdgpu_TgSizeX_PrefixSum_LiteralCount), 1, 1); ); @@ -1679,6 +1712,7 @@ void zstdgpu_SubmitStage2(zstdgpu_PerRequestContext req, ID3D12GraphicsCommandLi if (req->zstdCmpBlockCount > 0) { +#if 0
PIXBeginEvent(cmdList, PIX_COLOR_DEFAULT, L"[Decompress Literals]"); BIND_RS_PS_SRT(DecompressLiterals); cmdList->SetComputeRoot32BitConstant(1, req->zstdCmpBlockCount, 0); @@ -1688,6 +1722,20 @@ void zstdgpu_SubmitStage2(zstdgpu_PerRequestContext req, ID3D12GraphicsCommandLi cmdList->ExecuteIndirect(req->dispatchCmdSig, 1, argBuf, kzstdgpu_CounterIndex_DecompressLiteralsGroups * sizeof(uint32_t), NULL, 0); ); PIXEndEvent(cmdList); +#else + PIXBeginEvent(cmdList, PIX_COLOR_DEFAULT, L"[Decompress Literals (LDS Store Cache)]"); + d3d12aid_ComputeRsPs_Set(&req->DecompressLiterals_LdsStoreCache, cmdList); + cmdList->SetDescriptorHeaps(1, &req->srts.heap); + cmdList->SetComputeRootDescriptorTable(0, req->srts.DecompressLiteralsGpuHandle); + cmdList->SetComputeRootUnorderedAccessView(1, req->resData.gpuOnly.DecompressedLiterals->GetGPUVirtualAddress()); + cmdList->SetComputeRoot32BitConstant(2, req->zstdCmpBlockCount, 0); + + ID3D12Resource* argBuf = req->resData.gpuOnly.Counters; + ZSTDGPU_KERNEL_SCOPE(DecompressLiterals, cmdList, + cmdList->ExecuteIndirect(req->dispatchCmdSig, 1, argBuf, kzstdgpu_CounterIndex_DecompressLiteralsGroups * sizeof(uint32_t), NULL, 0); + ); + PIXEndEvent(cmdList); +#endif } if (req->zstdCmpBlockCount > 0) diff --git a/zstd/zstdgpu/zstdgpu_lds_decl_base.h b/zstd/zstdgpu/zstdgpu_lds_decl_base.h index 4a5f98a..346871f 100644 --- a/zstd/zstdgpu/zstdgpu_lds_decl_base.h +++ b/zstd/zstdgpu/zstdgpu_lds_decl_base.h @@ -26,5 +26,5 @@ #endif #define ZSTDGPU_LDS_REGION(name, size) zstdgpu_lds_uintptr_t GS_##name = GS_Base; \ - GS_Base += size; + GS_Base += (size); diff --git a/zstd/zstdgpu/zstdgpu_shaders.h b/zstd/zstdgpu/zstdgpu_shaders.h index f9c296f..dde38ad 100644 --- a/zstd/zstdgpu/zstdgpu_shaders.h +++ b/zstd/zstdgpu/zstdgpu_shaders.h @@ -2785,6 +2785,23 @@ static inline void zstdgpu_DecompressHuffmanCompressedLiterals(ZSTDGPU_RO_RAW_BU uint32_t bitsMax, uint32_t tgSize); +static inline void 
zstdgpu_DecompressHuffmanCompressedLiterals_StoreLdsCache(ZSTDGPU_RO_RAW_BUFFER(uint32_t) CompressedData, + ZSTDGPU_RO_BUFFER(uint32_t) LitStreamRemap, + ZSTDGPU_RO_BUFFER(zstdgpu_LitStreamInfo) LitRefs, + ZSTDGPU_RW_TYPED_BUFFER(uint32_t, uint8_t) DecompressedLiterals, + ZSTDGPU_RW_BUFFER(uint32_t) DecompressedLiteralsAsDwords, + ZSTDGPU_PARAM_LDS_IN(uint32_t) GS_HuffmanTable, + ZSTDGPU_PARAM_LDS_INOUT(uint32_t) GS_LiteralStoreCache, + uint32_t groupId, + uint32_t threadId, + uint32_t htGroupStart, + uint32_t htLiteralStart, + uint32_t htLiteralCount, + uint32_t bitsMax, + uint32_t tgSize, + uint32_t streamsPerGroup, + uint32_t cacheDwordsPerStream); + static void zstdgpu_ConvertThreadgroupIdToDecompressLiteralsInputs(ZSTDGPU_RO_BUFFER(uint32_t) LitGroupEndPerHuffmanTable, ZSTDGPU_RO_BUFFER(uint32_t) LitStreamEndPerHuffmanTable, @@ -2839,7 +2856,7 @@ static void zstdgpu_ConvertThreadgroupIdToDecompressLiteralsInputs(ZSTDGPU_RO_BU #define ZSTDGPU_DECOMPRESS_LITERALS_LDS(base, size) \ ZSTDGPU_LDS_SIZE(size) \ ZSTDGPU_LDS_BASE(base) \ - ZSTDGPU_LDS_REGION(HuffmanTable, kzstdgpu_MaxCount_HuffmanTableExpandedUInts) + ZSTDGPU_LDS_REGION(HuffmanTable, 1u << (kzstdgpu_MaxCount_HuffmanWeightBits - 1)) #include "zstdgpu_lds_decl_size.h" ZSTDGPU_DECOMPRESS_LITERALS_LDS(0, DecompressLiterals); @@ -2872,17 +2889,33 @@ static void zstdgpu_ShaderEntry_DecompressLiterals(ZSTDGPU_PARAM_INOUT(zstdgpu_D const uint32_t bitsMax = htInfo >> 16; const uint32_t codeTableSize = htInfo & 0xffffu; const uint32_t stateCnt = WaveReadLaneFirst(srt.inHuffmanTableRankIndex[htIndex * kzstdgpu_MaxCount_HuffmanWeightRanks + bitsMax]); + const uint32_t statePairCnt = stateCnt >> 1u; - // Expand Huffman Table - ZSTDGPU_FOR_WORK_ITEMS(stateId, stateCnt, threadId, kzstdgpu_TgSizeX_DecompressLiterals) + // Expand Huffman Table, pack 2 entries into single dword + ZSTDGPU_FOR_WORK_ITEMS(statePairId, statePairCnt, threadId, tgSize) { - const uint32_t symbolIndex = 
zstdgpu_BinarySearchMasked(srt.inHuffmanTableCodeAndSymbol, htIndex * kzstdgpu_MaxCount_HuffmanWeights, codeTableSize, stateId, 0x00ffffffu); - const uint32_t bitcntIndex = zstdgpu_BinarySearchMasked(srt.inHuffmanTableRankIndex, htIndex * kzstdgpu_MaxCount_HuffmanWeightRanks, bitsMax + 1, stateId, 0xffffffffu) - - htIndex * kzstdgpu_MaxCount_HuffmanWeightRanks; - const uint32_t symbol = srt.inHuffmanTableCodeAndSymbol[symbolIndex] >> 24; - const uint32_t bitcnt = bitsMax - bitcntIndex; + const uint32_t stateId0 = statePairId << 1u; + const uint32_t stateId1 = stateId0 + 1u; + + const uint32_t symbolIndex0 = zstdgpu_BinarySearchMasked(srt.inHuffmanTableCodeAndSymbol, htIndex * kzstdgpu_MaxCount_HuffmanWeights, codeTableSize, stateId0, 0x00ffffffu); + const uint32_t symbolIndex1 = zstdgpu_BinarySearchMasked(srt.inHuffmanTableCodeAndSymbol, htIndex * kzstdgpu_MaxCount_HuffmanWeights, codeTableSize, stateId1, 0x00ffffffu); + + const uint32_t bitcntIndex0 = zstdgpu_BinarySearchMasked(srt.inHuffmanTableRankIndex, htIndex * kzstdgpu_MaxCount_HuffmanWeightRanks, bitsMax + 1, stateId0, 0xffffffffu) + - htIndex * kzstdgpu_MaxCount_HuffmanWeightRanks; + + const uint32_t bitcntIndex1 = zstdgpu_BinarySearchMasked(srt.inHuffmanTableRankIndex, htIndex * kzstdgpu_MaxCount_HuffmanWeightRanks, bitsMax + 1, stateId1, 0xffffffffu) + - htIndex * kzstdgpu_MaxCount_HuffmanWeightRanks; + + const uint32_t symbol0 = srt.inHuffmanTableCodeAndSymbol[symbolIndex0] >> 24; + const uint32_t bitcnt0 = bitsMax - bitcntIndex0; - zstdgpu_LdsStoreU32(GS_HuffmanTable + stateId, (symbol << 16) | bitcnt); + const uint32_t symbol1 = srt.inHuffmanTableCodeAndSymbol[symbolIndex1] >> 24; + const uint32_t bitcnt1 = bitsMax - bitcntIndex1; + + const uint32_t symbolAndBitcnt0 = (symbol0 << 8) | bitcnt0; + const uint32_t symbolAndBitcnt1 = (symbol1 << 8) | bitcnt1; + + zstdgpu_LdsStoreU32(GS_HuffmanTable + statePairId, (symbolAndBitcnt1 << 16) | symbolAndBitcnt0); } GroupMemoryBarrierWithGroupSync(); @@ -2900,6 
+2933,7 @@ static void zstdgpu_ShaderEntry_DecompressLiterals(ZSTDGPU_PARAM_INOUT(zstdgpu_D bitsMax, tgSize ); + } // LDS partitioning macro lists for combined Huffman Table Initialisation + Literal Decompression @@ -2909,7 +2943,7 @@ static void zstdgpu_ShaderEntry_DecompressLiterals(ZSTDGPU_PARAM_INOUT(zstdgpu_D ZSTDGPU_LDS_REGION(CodeAndSymbol , kzstdgpu_MaxCount_HuffmanWeights) \ ZSTDGPU_LDS_REGION(PreInit , kzstdgpu_PreInitHuffmanTable_LdsSize) \ ZSTDGPU_LDS_REGION(RankIndex , kzstdgpu_MaxCount_HuffmanWeightRanks) \ - ZSTDGPU_LDS_REGION(HuffmanTable , kzstdgpu_MaxCount_HuffmanTableExpandedUInts) + ZSTDGPU_LDS_REGION(HuffmanTable , 1u << (kzstdgpu_MaxCount_HuffmanWeightBits - 1u)) #include "zstdgpu_lds_decl_size.h" ZSTDGPU_INIT_HUFFMAN_TABLE_AND_DECOMPRESS_LITERALS_LDS(0, InitHuffmanTableAndDecompressLiterals); @@ -2965,16 +2999,30 @@ static void zstdgpu_ShaderEntry_InitHuffmanTable_And_DecompressLiterals(ZSTDGPU_ GroupMemoryBarrierWithGroupSync(); const uint32_t stateCnt = zstdgpu_LdsLoadU32(GS_RankIndex + bitsMax); + const uint32_t statePairCnt = stateCnt >> 1u; // Expand Huffman Table - ZSTDGPU_FOR_WORK_ITEMS(stateId, stateCnt, threadId, kzstdgpu_TgSizeX_DecompressLiterals) + ZSTDGPU_FOR_WORK_ITEMS(statePairId, statePairCnt, threadId, kzstdgpu_TgSizeX_DecompressLiterals) { - const uint32_t symbolIndex = zstdgpu_BinarySearchLds(GS_CodeAndSymbol, 0, codeTableSize, stateId, 0x00ffffffu); - const uint32_t bitcntIndex = zstdgpu_BinarySearchLds(GS_RankIndex, 0, bitsMax + 1, stateId, 0xffffffffu); - const uint32_t symbol = zstdgpu_LdsLoadU32(GS_CodeAndSymbol + symbolIndex) >> 24; - const uint32_t bitcnt = bitsMax - bitcntIndex; + const uint32_t stateId0 = statePairId << 1u; + const uint32_t stateId1 = stateId0 + 1u; + + const uint32_t symbolIndex0 = zstdgpu_BinarySearchLds(GS_CodeAndSymbol, 0, codeTableSize, stateId0, 0x00ffffffu); + const uint32_t symbolIndex1 = zstdgpu_BinarySearchLds(GS_CodeAndSymbol, 0, codeTableSize, stateId1, 0x00ffffffu); + + const uint32_t 
bitcntIndex0 = zstdgpu_BinarySearchLds(GS_RankIndex, 0, bitsMax + 1, stateId0, 0xffffffffu); + const uint32_t bitcntIndex1 = zstdgpu_BinarySearchLds(GS_RankIndex, 0, bitsMax + 1, stateId1, 0xffffffffu); - zstdgpu_LdsStoreU32(GS_HuffmanTable + stateId, (symbol << 16) | bitcnt); + const uint32_t symbol0 = zstdgpu_LdsLoadU32(GS_CodeAndSymbol + symbolIndex0) >> 24; + const uint32_t symbol1 = zstdgpu_LdsLoadU32(GS_CodeAndSymbol + symbolIndex1) >> 24; + + const uint32_t bitcnt0 = bitsMax - bitcntIndex0; + const uint32_t bitcnt1 = bitsMax - bitcntIndex1; + + const uint32_t symbolAndBitcnt0 = (symbol0 << 8) | bitcnt0; + const uint32_t symbolAndBitcnt1 = (symbol1 << 8) | bitcnt1; + + zstdgpu_LdsStoreU32(GS_HuffmanTable + statePairId, (symbolAndBitcnt1 << 16) | symbolAndBitcnt0); } GroupMemoryBarrierWithGroupSync(); @@ -2999,9 +3047,11 @@ static inline void zstdgpu_SampleHuffmanSymbolAndBitcnt(ZSTDGPU_PARAM_INOUT(uint ZSTDGPU_PARAM_IN(uint32_t) state, ZSTDGPU_PARAM_LDS_IN(uint32_t) GS_HuffmanTable) { - const uint32_t symbolAndBitcnt = zstdgpu_LdsLoadU32(GS_HuffmanTable + state); - symbol = symbolAndBitcnt >> 16; - bitcnt = symbolAndBitcnt & 0xffffu; + const uint32_t statePairId = state >> 1; + const uint32_t stateIdInPair = state & 0x1u; + const uint32_t symbolAndBitcnt = zstdgpu_LdsLoadU32(GS_HuffmanTable + statePairId) >> (stateIdInPair << 4u); + symbol = (symbolAndBitcnt >> 8) & 0xffu; + bitcnt = symbolAndBitcnt & 0xffu; } void zstdgpu_DecompressHuffmanCompressedLiterals(ZSTDGPU_RO_RAW_BUFFER(uint32_t) CompressedData, @@ -3081,6 +3131,175 @@ void zstdgpu_DecompressHuffmanCompressedLiterals(ZSTDGPU_RO_RAW_BUFFER(uint32_t) } } +void zstdgpu_DecompressHuffmanCompressedLiterals_StoreLdsCache(ZSTDGPU_RO_RAW_BUFFER(uint32_t) CompressedData, + ZSTDGPU_RO_BUFFER(uint32_t) LitStreamRemap, + ZSTDGPU_RO_BUFFER(zstdgpu_LitStreamInfo) LitRefs, + ZSTDGPU_RW_TYPED_BUFFER(uint32_t, uint8_t) DecompressedLiterals, + ZSTDGPU_RW_BUFFER(uint32_t) DecompressedLiteralsAsDwords, + 
/**
 * Decompresses one Huffman-compressed literal stream per thread, staging decoded
 * bytes as dwords in an LDS cache and flushing them to memory cooperatively so
 * that device writes are dword-sized and coalesced.
 *
 * Per thread: decode up to 3 unaligned head bytes, then decode 4 literals at a
 * time into a dword and stage up to 'cacheDwordsPerStream' dwords in this
 * thread's LDS cache region; the whole group then flushes every stream's cached
 * dwords to 'DecompressedLiteralsAsDwords'; finally decode up to 3 unaligned
 * tail bytes.
 *
 * @param CompressedData              raw compressed input
 * @param LitStreamRemap              maps sorted stream index -> literal stream id
 * @param LitRefs                     per-stream src/dst segment descriptors
 * @param DecompressedLiterals        byte-typed view of the output (head/tail bytes)
 * @param DecompressedLiteralsAsDwords dword view of the same output (bulk stores)
 * @param GS_HuffmanTable             packed Huffman table in LDS (two entries/dword)
 * @param GS_LiteralStoreCache        LDS cache, 'cacheDwordsPerStream' dwords per thread
 * @param groupId/threadId            dispatch coordinates
 * @param htGroupStart                first group handling Huffman-compressed streams
 * @param htLiteralStart/htLiteralCount  range of Huffman-compressed streams
 * @param bitsMax                     max code length; bit window for peeks
 * @param tgSize                      threadgroup size (used by the cooperative flush)
 * @param streamsPerGroup             streams handled per threadgroup
 * @param cacheDwordsPerStream        LDS cache dwords per stream
 *
 * NOTE(review): the loop-exit test uses WaveActiveAnyTrue while the loop body
 * contains group barriers — this is only safe when the threadgroup is a single
 * wave (tgSize == wave size); confirm against the dispatch configuration.
 */
void zstdgpu_DecompressHuffmanCompressedLiterals_StoreLdsCache(ZSTDGPU_RO_RAW_BUFFER(uint32_t) CompressedData,
                                                               ZSTDGPU_RO_BUFFER(uint32_t) LitStreamRemap,
                                                               ZSTDGPU_RO_BUFFER(zstdgpu_LitStreamInfo) LitRefs,
                                                               ZSTDGPU_RW_TYPED_BUFFER(uint32_t, uint8_t) DecompressedLiterals,
                                                               ZSTDGPU_RW_BUFFER(uint32_t) DecompressedLiteralsAsDwords,
                                                               ZSTDGPU_PARAM_LDS_IN(uint32_t) GS_HuffmanTable,
                                                               ZSTDGPU_PARAM_LDS_INOUT(uint32_t) GS_LiteralStoreCache,
                                                               uint32_t groupId,
                                                               uint32_t threadId,
                                                               uint32_t htGroupStart,
                                                               uint32_t htLiteralStart,
                                                               uint32_t htLiteralCount,
                                                               uint32_t bitsMax,
                                                               uint32_t tgSize,
                                                               uint32_t streamsPerGroup,
                                                               uint32_t cacheDwordsPerStream)
{
    // (removed: ZSTDGPU_UNUSED(threadId) — threadId is used throughout;
    //  removed: unused local 'maxBitcntMask')

    //
    // The start of decompression of Huffman-compressed literals
    //
    const uint32_t thisGroupLiteralStart  = (groupId - htGroupStart) * streamsPerGroup;
    const uint32_t thisGroupLiteralRemain = zstdgpu_MinU32(htLiteralCount - thisGroupLiteralStart, streamsPerGroup);

    zstdgpu_HuffmanStream stream;

    uint32_t byteAlignedBeg  = 0;
    uint32_t byteAlignedEnd  = 0;
    uint32_t dwordAlignedEnd = 0;
    uint32_t dwordAlignedBeg = 0;

    ZSTDGPU_BRANCH if (threadId < thisGroupLiteralRemain)
    {
        const uint32_t literalStreamId = LitStreamRemap[htLiteralStart + thisGroupLiteralStart + threadId];
        zstdgpu_LitStreamInfo compressedLiteral = LitRefs[literalStreamId];
        if (compressedLiteral.dst.size != 0) // derived from block Regenerated_Size
        {
            zstdgpu_HuffmanStream_InitWithSegment(stream, CompressedData, compressedLiteral.src, bitsMax);
        }

        // Clamp the dword-aligned range inside [byteAlignedBeg, byteAlignedEnd)
        // so a stream shorter than one dword degenerates to head/tail bytes only.
        byteAlignedBeg  = compressedLiteral.dst.offs;
        byteAlignedEnd  = byteAlignedBeg + compressedLiteral.dst.size;
        dwordAlignedEnd = zstdgpu_MaxU32(byteAlignedBeg, byteAlignedEnd & ~3u);
        dwordAlignedBeg = zstdgpu_MinU32((byteAlignedBeg + 3u) & ~3u, dwordAlignedEnd);
    }

    uint32_t symbol = 0;
    uint32_t bitcnt = 0;
    uint32_t state  = 0;

    // Handle head bytes (up to 3 bytes before the first dword-aligned address)
    // with byte-granular stores.
    for (; byteAlignedBeg < dwordAlignedBeg; ++byteAlignedBeg)
    {
        state = zstdgpu_HuffmanStream_RefillAndPeek(stream);
        zstdgpu_SampleHuffmanSymbolAndBitcnt(symbol, bitcnt, state, GS_HuffmanTable);
        zstdgpu_TypedStoreU8(DecompressedLiterals, byteAlignedBeg, symbol);
        zstdgpu_HuffmanStream_Consume(stream, bitcnt);
    }

    const uint32_t kStoreCacheBankCount = 32;
    const uint32_t kStoreCacheBankMask  = kStoreCacheBankCount - 1u;

    const uint32_t storeCacheThreadOffset = threadId * cacheDwordsPerStream;

    const uint32_t dwordIdxBeg = dwordAlignedBeg >> 2;
    const uint32_t dwordIdxEnd = dwordAlignedEnd >> 2;
    uint32_t dwordIdx = dwordIdxBeg;

    // Every thread must iterate as long as ANY thread still has dwords to decode,
    // because the flush below is cooperative across the whole group.
    do
    {
        const uint32_t dwordIdxBatchBeg = dwordIdx;
        const uint32_t dwordIdxBatchEnd = zstdgpu_MinU32(dwordIdx + cacheDwordsPerStream, dwordIdxEnd);

        // Populate LDS cache of at most 'cacheDwordsPerStream' dwords per stream:
        // four literals are decoded per dword, little-endian byte order.
        for (; dwordIdx < dwordIdxBatchEnd; ++dwordIdx)
        {
            uint32_t dword = 0;
            state = zstdgpu_HuffmanStream_RefillAndPeek(stream);
            zstdgpu_SampleHuffmanSymbolAndBitcnt(symbol, bitcnt, state, GS_HuffmanTable);
            dword |= symbol;
            zstdgpu_HuffmanStream_Consume(stream, bitcnt);

            state = zstdgpu_HuffmanStream_RefillAndPeek(stream);
            zstdgpu_SampleHuffmanSymbolAndBitcnt(symbol, bitcnt, state, GS_HuffmanTable);
            dword |= symbol << 8;
            zstdgpu_HuffmanStream_Consume(stream, bitcnt);

            state = zstdgpu_HuffmanStream_RefillAndPeek(stream);
            zstdgpu_SampleHuffmanSymbolAndBitcnt(symbol, bitcnt, state, GS_HuffmanTable);
            dword |= symbol << 16;
            zstdgpu_HuffmanStream_Consume(stream, bitcnt);

            state = zstdgpu_HuffmanStream_RefillAndPeek(stream);
            zstdgpu_SampleHuffmanSymbolAndBitcnt(symbol, bitcnt, state, GS_HuffmanTable);
            dword |= symbol << 24;
            zstdgpu_HuffmanStream_Consume(stream, bitcnt);

            const uint32_t dwordIdxInBatch = dwordIdx - dwordIdxBatchBeg;

            // Stagger the slot by threadId so simultaneous writes from different
            // threads land in different LDS banks (assumes 'cacheDwordsPerStream'
            // is a multiple of the 32-bank stride — TODO confirm).
            const uint32_t dwordIdxInCache = (dwordIdxInBatch & ~kStoreCacheBankMask) + ((dwordIdxInBatch + threadId) & kStoreCacheBankMask);
            zstdgpu_LdsStoreU32(GS_LiteralStoreCache + storeCacheThreadOffset + dwordIdxInCache, dword);
        }
        // Ensure all threads' LDS cache writes are visible before cooperative read
        GroupMemoryBarrierWithGroupSync();

        // Move at most 'cacheDwordsPerStream' dwords from LDS cache to memory using the entire threadgroup.
        for (uint32_t i = 0; i < thisGroupLiteralRemain; ++i)
        {
            const uint32_t dwordCntInBatch = WaveReadLaneAt(dwordIdxBatchEnd - dwordIdxBatchBeg, i);
            const uint32_t dstDwordIdx     = WaveReadLaneAt(dwordIdxBatchBeg, i);

            ZSTDGPU_FOR_WORK_ITEMS(dwordIdxToStore, dwordCntInBatch, threadId, tgSize)
            {
                // Inverse of the store swizzle: thread i stored logical dword d at cache position ((d + i) & mask)
                const uint32_t dwordIdxInCache = (dwordIdxToStore & ~kStoreCacheBankMask) + ((dwordIdxToStore + i) & kStoreCacheBankMask);
                const uint32_t dword = zstdgpu_LdsLoadU32(GS_LiteralStoreCache + i * cacheDwordsPerStream + dwordIdxInCache);

                DecompressedLiteralsAsDwords[dstDwordIdx + dwordIdxToStore] = dword;
            }
        }

        // BUGFIX: close the write-after-read hazard. Without this barrier the next
        // iteration's populate loop may overwrite cache slots that another thread
        // is still reading in the cooperative flush above.
        GroupMemoryBarrierWithGroupSync();
    }
    while (WaveActiveAnyTrue(dwordIdx < dwordIdxEnd));

    // Handle tail bytes (up to 3 bytes after the last dword-aligned address)
    // with byte-granular stores.
    for (uint32_t tailByte = dwordAlignedEnd; tailByte < byteAlignedEnd; ++tailByte)
    {
        state = zstdgpu_HuffmanStream_RefillAndPeek(stream);
        zstdgpu_SampleHuffmanSymbolAndBitcnt(symbol, bitcnt, state, GS_HuffmanTable);
        zstdgpu_TypedStoreU8(DecompressedLiterals, tailByte, symbol);
        zstdgpu_HuffmanStream_Consume(stream, bitcnt);
    }
}