diff --git a/.gitignore b/.gitignore
index 5e834f14..18b3bed3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -31,4 +31,9 @@ metals.sbt
 smoke
 julia
+# Model weight files
+*.gguf
+.lwjgl/
+# AMD Radeon GPU Analyzer
+rga-*/
diff --git a/AGENTS.md b/AGENTS.md
new file mode 100644
index 00000000..99fcdf71
--- /dev/null
+++ b/AGENTS.md
@@ -0,0 +1,462 @@
+# AGENTS.md - Cyfra Development Guide
+
+## Project Overview
+
+Cyfra is a Scala 3 library for GPU computing that compiles a Scala DSL to SPIR-V and executes it via Vulkan. It supports Linux, Windows, and macOS (with MoltenVK).
+
+**Key modules:**
+- `cyfra-dsl` - Core DSL for GPU computations (values, expressions, bindings, structs)
+- `cyfra-compiler` - Compiles DSL to SPIR-V bytecode
+- `cyfra-vulkan` - Low-level Vulkan bindings via LWJGL
+- `cyfra-runtime` - Runtime execution via VkCyfraRuntime
+- `cyfra-core` - GProgram abstraction connecting DSL and runtime
+- `cyfra-foton` - High-level GFunction API, image/animation rendering, ray tracing
+- `cyfra-fluids` - 3D fluid simulation example
+- `cyfra-analytics` - Customer segmentation server (http4s/tapir + GPU)
+- `cyfra-fs2` - fs2 streaming interop (GPipe, GCluster)
+- `cyfra-llama` - LLM inference with F16/F32 pipelines
+- `cyfra-spirv-tools` - SPIR-V validation/optimization/disassembly
+- `cyfra-utility` - Logging, utility functions
+- `cyfra-e2e-test` - End-to-end tests
+- `cyfra-examples` - Example programs
+
+## Essential Commands
+
+### Build & Compile
+```bash
+sbt compile                # Compile all modules
+sbt "project dsl" compile  # Compile specific module
+```
+
+### Testing
+```bash
+sbt test                    # Run all tests (requires Vulkan-capable GPU)
+sbt "project fluids" test   # Test specific module
+sbt "project e2eTest" test  # E2E tests (forked, custom JVM options)
+```
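+
+Standard sbt test filtering works here as well; for example, to run a single suite (the suite name below is only an illustration):
+
+```bash
+sbt "project dsl" "testOnly io.computenode.cyfra.SomeSpec"  # run one suite by fully qualified name
+```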
+
+### Formatting
+```bash
+sbt formatAll       # Format all code (scalafmt)
+sbt formatCheckAll  # Check formatting (CI uses this)
+```
+
+### Running Examples
+```bash
+sbt "project examples" run   # Run examples
+sbt "project analytics" run  # Start segmentation server (port 8081)
+sbt "project fluids" run     # Run fluid simulation
+sbt "project llama" run      # Run LLM inference (requires model file)
+```
+
+### Llama Runner
+```bash
+# From sbt:
+sbt "project llama" "run --model models/Llama-3.2-1B-Instruct-f16.gguf -i"
+sbt "project llama" "run -m model.gguf --measure -n 128"  # Benchmark
+```
+
+## Code Organization
+
+```
+io.computenode.cyfra
+├── dsl/               # DSL types and abstractions
+│   ├── Value.*        # Int32, Float32, Float16, Vec3, Vec4, etc.
+│   ├── Expression.*   # Expression tree nodes
+│   ├── algebra/       # ScalarAlgebra, VectorAlgebra (givens)
+│   ├── binding/       # GBuffer, GUniform, GShared
+│   ├── collections/   # GArray, GArray2D, GSeq
+│   ├── control/       # When (conditionals), Pure, Scope
+│   ├── gio/           # GIO monad for GPU operations
+│   ├── library/       # Functions, Math3D, Color, Random
+│   └── struct/        # GStruct (case class → GPU struct)
+├── spirv/             # SPIR-V compiler
+│   └── compilers/     # DSLCompiler, ExpressionCompiler, etc.
+├── core/              # Core abstractions
+│   ├── GProgram       # GPU program definition
+│   ├── GExecution     # Execution model
+│   ├── GCodec         # Type serialization to ByteBuffer
+│   ├── layout/Layout  # Layout derivation for program layouts
+│   └── CyfraRuntime   # Runtime trait
+├── runtime/           # Vulkan runtime
+│   └── VkCyfraRuntime # Main runtime implementation
+├── vulkan/            # Vulkan bindings (LWJGL)
+├── foton/             # High-level APIs
+│   ├── GFunction      # Simple Array[A] => Array[B] function
+│   ├── rt/            # Ray tracing (Camera, Scene, Shape, Material)
+│   └── animation/     # AnimatedFunction, AnimationRenderer
+└── llama/             # LLM inference
+    ├── inference/     # LlamaInference, CPUInference
+    ├── pipeline/      # LlamaF16Pipeline, LlamaF32Pipeline
+    ├── programs/f16/  # F16 GPU programs (attention, matmul, etc.)
+    ├── programs/f32/  # F32 GPU programs
+    └── gguf/          # GGUF model loading
+```
+
+## DSL Patterns
+
+### Importing the DSL
+```scala
+import io.computenode.cyfra.dsl.{*, given}         // All DSL types and givens
+import io.computenode.cyfra.core.GCodec.{*, given} // Codecs
+import io.computenode.cyfra.core.layout.Layout     // Layout derivation
+```
+
+### Creating a GFunction (High-Level)
+```scala
+import io.computenode.cyfra.foton.GFunction
+import io.computenode.cyfra.runtime.VkCyfraRuntime
+
+given CyfraRuntime = VkCyfraRuntime()
+
+val fn: GFunction[GStruct.Empty, Float32, Float32] = GFunction: x =>
+  (x + 1.0f) * (x - 2.0f)
+
+val result: Array[Float] = fn.run(inputArray)
+```
+
+### GPU Structs (Case Classes for Uniforms)
+```scala
+// Define a struct that can be passed as a uniform
+case class MyParams(
+  dt: Float32,
+  gridSize: Int32,
+) extends GStruct[MyParams]
+
+object MyParams:
+  given GStructSchema[MyParams] = GStructSchema.derived
+```
+
+### Program Layout (Buffers and Uniforms)
+```scala
+// Layout defines all GPU bindings for a program
+case class ProgramLayout(
+  input: GBuffer[Float16],
+  weight: GBuffer[Float16],
+  output: GBuffer[Float16],
+  params: GUniform[MyParams],
+) derives Layout
+```
+
+### GProgram Pattern (Low-Level)
+```scala
+object MyProgram:
+  case class Sizes(numElements: Int, featureSize: Int)
+
+  case class ProgramLayout(
+    input: GBuffer[Float32],
+    output: GBuffer[Float32],
+    params: GUniform[MyParams],
+  ) derives Layout
+
+  def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] =
+    GProgram[Sizes, ProgramLayout](
+      // Layout factory: creates buffer specs from params
+      layout = s => ProgramLayout(
+        input = GBuffer[Float32](s.numElements),
+        output = GBuffer[Float32](s.numElements),
+        params = GUniform[MyParams](),
+      ),
+      // Dispatch: determines workgroup count
+      dispatch = (_, s) => StaticDispatch(((s.numElements + 255) / 256, 1, 1)),
+      workgroupSize = (256, 1, 1),
+    ): layout =>
+      // GPU kernel body
+      val tid = GIO.invocationId
+      GIO.when(tid < sizes.numElements):
+        val x = GIO.read[Float32](layout.input, tid)
+        val p = layout.params.read
+        GIO.write[Float32](layout.output, tid, x * p.dt)
+```
+
+### GIO Operations
+```scala
+// Read from buffer at index
+val value = GIO.read[Float32](buffer, index)
+val value = buffer.read(index) // alternative syntax
+
+// Write to buffer (returns GIO for chaining)
+GIO.write[Float32](buffer, index, value)
+
+// Read uniform struct
+val params = layout.params.read
+
+// Get thread IDs
+val globalId = GIO.invocationId     // Global invocation ID (flattened)
+val localId = GIO.localInvocationId // Local ID within workgroup (Vec3)
+val workgroupId = GIO.workgroupId   // Workgroup ID (Vec3)
+
+// Conditionals
+GIO.when(condition):
+  // body executed if condition true
+
+// Barriers
+GIO.barrier // Workgroup memory barrier
+
+// Subgroup operations (warp-level)
+val sum = GIO.subgroupAdd(localValue)
+val max = GIO.subgroupMax(localValue)
+
+// Loops
+GIO.repeat(iterations): i =>
+  // body with iteration index i
+
+// Fold loops (accumulate result)
+GIO.foldRepeat[Float32](iterations, initialValue): (i, acc) =>
+  // return new accumulator value
+```
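+
+A hedged sketch tying these together - one thread guards its index, folds over `n` elements, and writes a single result. The `n`, `a`, `b`, and `out` names are hypothetical layout members, and the curried `foldRepeat` shape and for-comprehension sequencing are assumed from the snippets above and the attention pattern later in this guide:
+
+```scala
+// Sketch only: dot-product-style thread body under the assumptions above
+val tid = GIO.invocationId
+GIO.when(tid < n):
+  for
+    sum <- GIO.foldRepeat[Float32](n, 0.0f)((i, acc) => acc + GIO.read[Float32](a, i) * GIO.read[Float32](b, i))
+    _ <- GIO.write[Float32](out, tid, sum)
+  yield GStruct.Empty()
+```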
+
+### GSeq - Lazy GPU Sequences
+```scala
+// Generate sequence: start value + next function
+val seq = GSeq.gen[Int32](startIdx, _ + stride)
+  .limit(numIterations)
+  .fold(0.0f, (sum: Float32, idx: Int32) =>
+    val value = GIO.read[Float32](buffer, idx)
+    sum + value
+  )
+
+// With unrolling hint for small fixed loops
+GSeq.gen[Int32](0, _ + 1)
+  .limit(headSize)
+  .unroll // hints the compiler to unroll the generated loop
+  .fold(0.0f, (acc, d) => acc + values(d))
+
+// Filtering
+GSeq.gen[Int32](0, _ + 1)
+  .limit(maxLen)
+  .takeWhile(_ < actualLen)
+  .fold(...)
+```
+
+### Conditionals
+```scala
+// When expression (returns value)
+when(condition)(
+  thenValue
+).otherwise(
+  elseValue
+)
+
+// Chained conditions
+when(cond1)(val1)
+  .elseWhen(cond2)(val2)
+  .otherwise(val3)
+
+// GIO.when for side effects
+GIO.when(tid < numElements):
+  GIO.write(output, tid, value)
+```
+
+### Shared Memory (Workgroup Local)
+```scala
+// Declare shared memory buffer
+val sharedMem = GShared[Float32](256) // 256 elements
+
+// Write to shared memory
+sharedMem.write(localIdx, value)
+
+// Read from shared memory
+val v = sharedMem.read(localIdx)
+
+// Barrier before reading what others wrote
+GIO.barrier
+```
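+
+Combining shared memory with a barrier - a minimal per-workgroup sum, as a hedged sketch: `input` and `partialSums` are hypothetical buffers, the flat local index via `.x` and the curried form of `GIO.when` are assumptions, and lane 0 is selected with `< 1` because `<` is the comparison documented above:
+
+```scala
+val tile = GShared[Float32](256)        // one slot per thread of a 256-wide workgroup
+val localIdx = GIO.localInvocationId.x  // assumed flat index within the workgroup
+
+def foldTile =                          // lane 0 sums the whole tile
+  val sum = GSeq.gen[Int32](0, _ + 1).limit(256).fold(0.0f, (acc, i) => acc + tile.read(i))
+  GIO.write[Float32](partialSums, GIO.workgroupId.x, sum) // component type assumed to match the index
+
+for
+  _ <- tile.write(localIdx, GIO.read[Float32](input, GIO.invocationId))
+  _ <- GIO.barrier                      // make the tile visible to the whole workgroup
+  _ <- GIO.when(localIdx < 1)(foldTile)
+yield GStruct.Empty()
+```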
+
+### Vector Operations
+```scala
+val v1: Vec4[Float32] = vec4(1.0f, 2.0f, 3.0f, 4.0f)
+val v2 = v1 * 2.0f + vec4(1.0f, 1.0f, 1.0f, 1.0f)
+val dotProduct = v1.dot(v2)
+val xyz: Vec3[Float32] = v1.xyz // Swizzle
+
+// Type conversions
+val f16: Float16 = f32Value.asFloat16
+val f32: Float32 = f16Value.asFloat32
+val intVal: Int32 = floatVal.asInt
+val floatVal: Float32 = intVal.asFloat
+```
+
+### Math Functions
+```scala
+// From io.computenode.cyfra.dsl.library.Functions
+sqrt(x)
+exp(x)
+log(x)
+sin(x), cos(x), tan(x)
+abs(x)
+min(a, b), max(a, b)
+clamp(x, minVal, maxVal)
+mix(a, b, t) // Linear interpolation
+```
+
+## Testing Patterns
+
+### Test Setup
+```scala
+class MyTest extends munit.FunSuite:
+  var runtime: VkCyfraRuntime = null
+
+  override def beforeAll(): Unit =
+    runtime = VkCyfraRuntime()
+
+  override def afterAll(): Unit =
+    if runtime != null then runtime.close()
+
+  test("description"):
+    given VkCyfraRuntime = runtime
+    // test code
+```
+
+### E2E Tests
+- Located in `cyfra-e2e-test/src/test/scala/`
+- Run forked with custom JVM options: `-Dorg.lwjgl.system.stackSize=1024`
+- Buffer sizes must be multiples of 256 bytes (Vulkan alignment)
+
+## Style Conventions
+
+### Scalafmt Configuration
+- Max column: 150
+- Scala 3 dialect
+- Trailing commas: always
+- Scala 3 syntax rewrites enabled
+- Run `sbt formatAll` before committing
+
+### Naming
+- GPU programs: `*Program` (e.g., `AdvectionProgram`, `RMSNormProgram`)
+- Program layouts: `ProgramLayout` or `*Layout` case class with `derives Layout`
+- GPU struct schemas: companion object with `given GStructSchema[T] = GStructSchema.derived`
+- F16 vs F32 variants: prefix with `F16*` or `F32*`
+- Size parameters: `Sizes` case class with computed properties
+
+### Code Style
+- Use explicit type annotations for public APIs
+- Prefer extension methods for DSL operations
+- Workgroup sizes typically `(256, 1, 1)` or `(128, 1, 1)`
+- WARP_SIZE constant = 32 (matches GPU subgroup size)
+- Compile-time constants as vals in enclosing scope, lifted to Int32/Float32 in kernel
+
+## Important Gotchas
+
+### Vulkan/GPU Requirements
+- Tests require a Vulkan-capable GPU
+- Buffer sizes must be multiples of 256 bytes (Vulkan alignment)
+- CI runs `formatCheckAll; compile` only (no GPU on CI runners)
+
+### Runtime Lifecycle
+```scala
+// Always close runtime
+val runtime = VkCyfraRuntime()
+try
+  // use runtime
+finally
+  runtime.close()
+
+// Or use the helper:
+VkCyfraRuntime.using:
+  // runtime available as given
+```
+
+### Validation Layers (Development)
+```bash
+# Enable Vulkan validation layers
+-Dio.computenode.cyfra.vulkan.validation=true
+
+# macOS additional settings
+-Dorg.lwjgl.vulkan.libname=libvulkan.1.dylib
+-Djava.library.path=$VULKAN_SDK/lib
+```
+
+### SPIR-V Compilation
+- DSL compiles to SPIR-V at runtime via `DSLCompiler.compile()`
+- Use `cyfra-spirv-tools` for validation/debugging
+- `SpirvDisassembler` outputs human-readable SPIR-V assembly
+
+### ByteBuffer Handling
+- Use `ByteOrder.nativeOrder()` for all GPU buffers
+- GCodec handles serialization: `codec.toByteBuffer(buffer, array)`
+- Always `rewind()` buffers before passing to GPU
+
+### Numeric Precision
+- F16 programs use F32 for intermediate computations (numerical stability)
+- Use `.asFloat32` / `.asFloat16` for conversions
+- Subgroup operations (subgroupAdd, subgroupMax) are hardware-accelerated
+
+### GSeq Limitations
+- Must have `.limit(n)` before `.fold()` - infinite streams not supported
+- Use `.unroll` only for small, fixed-size loops
+
+## Module Dependencies
+
+```
+utility
+├── spirvTools
+├── vulkan
+└── dsl
+    └── compiler
+        └── core
+            └── runtime
+                ├── foton
+                │   ├── fluids
+                │   ├── analytics
+                │   └── examples
+                ├── fs2interop
+                │   └── e2eTest
+                └── llama
+```
+
+## CI/CD
+
+- **CI**: `sbt "formatCheckAll; compile"` on push to main/tags and PRs to dev
+- **Release**: `sbt ci-release` on version tags (v*)
+- **JDK**: GraalVM Java 21 (CI), Temurin 21 (release)
+- **Artifacts**: Published to Maven Central via sbt-ci-release
+
+## Documentation
+
+- Project docs: `docs/` directory (Docusaurus)
+- API docs: In-code Scaladoc
+- Examples: `cyfra-examples/src/main/scala/`
+
+## Common Program Patterns
+
+### Element-wise Operation
+```scala
+GProgram[Sizes, Layout](
+  layout = s => Layout(GBuffer(s.n), GBuffer(s.n)),
+  dispatch = (_, s) => StaticDispatch(((s.n + 255) / 256, 1, 1)),
+  workgroupSize = (256, 1, 1),
+): layout =>
+  val tid = GIO.invocationId
+  GIO.when(tid < sizes.n):
+    val x = GIO.read(layout.input, tid)
+    GIO.write(layout.output, tid, transform(x))
+```
+
+### Reduction with Subgroups
+```scala
+val localSum = GSeq.gen[Int32](laneId, _ + WARP_SIZE)
+  .limit(numIterations)
+  .fold(0.0f, (sum, k) =>
+    when(k < totalSize)(sum + GIO.read(input, k))
+      .otherwise(sum)
+  )
+val totalSum = GIO.subgroupAdd(localSum)
+```
+
+### Attention Pattern (Q·K→softmax→V)
+```scala
+// Phase 1: Compute scores
+for
+  scores <- computeQKScores
+  _ <- GIO.barrier
+  // Phase 2: Softmax
+  maxScore <- GIO.pure(GIO.subgroupMax(localMax))
+  expSum <- computeExpSum(scores, maxScore)
+  _ <- GIO.barrier
+  // Phase 3: Normalize
+  _ <- normalizeScores(scores, expSum)
+  _ <- GIO.barrier
+  // Phase 4: Weighted sum of V
+  _ <- computeWeightedV(scores, vCache)
+yield GStruct.Empty()
+```
diff --git a/build.sbt b/build.sbt
index eaa97261..121fee15 100644
--- a/build.sbt
+++ b/build.sbt
@@ -94,6 +94,9 @@ lazy val tapirSettings = Seq(
 lazy val utility = (project in file("cyfra-utility"))
   .settings(commonSettings)
+
.settings( + libraryDependencies += "net.java.dev.jna" % "jna" % "5.14.0", + ) lazy val spirvTools = (project in file("cyfra-spirv-tools")) .settings(commonSettings) @@ -153,10 +156,15 @@ lazy val e2eTest = (project in file("cyfra-e2e-test")) .settings(publish / skip := true) .dependsOn(runtime, fs2interop, foton) +lazy val llama = (project in file("cyfra-llama")) + .settings(commonSettings, runnerSettings) + .settings(publish / skip := true) + .dependsOn(runtime, dsl, core, utility) + lazy val root = (project in file(".")) .settings(name := "Cyfra") .settings(publish / skip := true) - .aggregate(compiler, dsl, foton, core, runtime, vulkan, examples, fs2interop, fluids, analytics, utility, spirvTools, vscode) + .aggregate(compiler, dsl, foton, core, runtime, vulkan, examples, fs2interop, fluids, analytics, utility, spirvTools, vscode, llama) e2eTest / Test / javaOptions ++= Seq("-Dorg.lwjgl.system.stackSize=1024", "-DuniqueLibraryNames=true") diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala index 96490071..bf3235bb 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Context.scala @@ -4,7 +4,7 @@ import io.computenode.cyfra.dsl.binding.{GBuffer, GUniform} import io.computenode.cyfra.dsl.macros.FnCall.FnIdentifier import io.computenode.cyfra.spirv.SpirvConstants.HEADER_REFS_TOP import io.computenode.cyfra.spirv.compilers.FunctionCompiler.SprivFunction -import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.ArrayBufferBlock +import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.{ArrayBufferBlock, SharedBlock} import izumi.reflect.Tag import izumi.reflect.macrortti.LightTypeTag @@ -13,15 +13,24 @@ private[cyfra] case class Context( funPointerTypeMap: Map[Int, Int] = Map(), uniformPointerMap: Map[Int, Int] = Map(), inputPointerMap: Map[Int, Int] = Map(), + workgroupPointerMap: Map[Int, Int] = Map(), funcTypeMap: Map[(LightTypeTag, List[LightTypeTag]), Int] = Map(), voidTypeRef: Int = -1, voidFuncTypeRef: Int = -1, workerIndexRef: Int = -1, + localInvocationIndexRef: Int = -1, + localInvocationIdRef: Int = -1, + workgroupIdRef: Int = -1, + numWorkgroupsRef: Int = -1, + subgroupIdRef: Int = -1, + subgroupLocalInvocationIdRef: Int = -1, + subgroupSizeRef: Int = -1, uniformVarRefs: Map[GUniform[?], Int] = Map.empty, bindingToStructType: Map[Int, Int] = Map.empty, constRefs: Map[(Tag[?], Any), Int] = Map(), exprRefs: Map[Int, Int] = Map(), bufferBlocks: Map[GBuffer[?], ArrayBufferBlock] = Map(), + sharedVarRefs: Map[Int, SharedBlock] = Map(), nextResultId: Int = HEADER_REFS_TOP, nextBinding: Int = 0, exprNames: Map[Int, String] = Map(), diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala index 1f8c4cb6..cbce430c 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/Opcodes.scala @@ -66,7 +66,7 @@ private[cyfra] object Opcodes: override def toString: String = s"%$result" val MagicNumber = Code("MagicNumber", 0x07230203) - val Version = Code("Version", 0x00010000) + val Version = Code("Version", 0x00010300) // SPIR-V 1.3 for GroupNonUniform val Revision = Code("Revision", 8) val Generator = Code("Generator", 0) val OpCodeMask = Code("OpCodeMask", 0xffff) @@ -516,6 +516,20 @@ 
private[cyfra] object Opcodes: val Reduce = Code("Reduce", 0) val InclusiveScan = Code("InclusiveScan", 1) val ExclusiveScan = Code("ExclusiveScan", 2) + val ClusteredReduce = Code("ClusteredReduce", 3) + + object MemorySemantics: + val None = Code("None", 0x0) + val Acquire = Code("Acquire", 0x2) + val Release = Code("Release", 0x4) + val AcquireRelease = Code("AcquireRelease", 0x8) + val SequentiallyConsistent = Code("SequentiallyConsistent", 0x10) + val UniformMemory = Code("UniformMemory", 0x40) + val SubgroupMemory = Code("SubgroupMemory", 0x80) + val WorkgroupMemory = Code("WorkgroupMemory", 0x100) + val CrossWorkgroupMemory = Code("CrossWorkgroupMemory", 0x200) + val AtomicCounterMemory = Code("AtomicCounterMemory", 0x400) + val ImageMemory = Code("ImageMemory", 0x800) object KernelEnqueueFlags: val NoWait = Code("NoWait", 0) @@ -589,6 +603,14 @@ private[cyfra] object Opcodes: val SubgroupDispatch = Code("SubgroupDispatch", 58) val NamedBarrier = Code("NamedBarrier", 59) val PipeStorage = Code("PipeStorage", 60) + val GroupNonUniform = Code("GroupNonUniform", 61) + val GroupNonUniformVote = Code("GroupNonUniformVote", 62) + val GroupNonUniformArithmetic = Code("GroupNonUniformArithmetic", 63) + val GroupNonUniformBallot = Code("GroupNonUniformBallot", 64) + val GroupNonUniformShuffle = Code("GroupNonUniformShuffle", 65) + val GroupNonUniformShuffleRelative = Code("GroupNonUniformShuffleRelative", 66) + val GroupNonUniformClustered = Code("GroupNonUniformClustered", 67) + val GroupNonUniformQuad = Code("GroupNonUniformQuad", 68) val SubgroupBallotKHR = Code("SubgroupBallotKHR", 4423) val DrawParameters = Code("DrawParameters", 4427) val SubgroupVoteKHR = Code("SubgroupVoteKHR", 4431) @@ -949,6 +971,42 @@ private[cyfra] object Opcodes: val OpSubgroupImageBlockReadINTEL = Code("OpSubgroupImageBlockReadINTEL", 5577) val OpSubgroupImageBlockWriteINTEL = Code("OpSubgroupImageBlockWriteINTEL", 5578) + // GroupNonUniform operations (Vulkan 1.1+) + val OpGroupNonUniformElect = Code("OpGroupNonUniformElect", 333) + val OpGroupNonUniformAll = Code("OpGroupNonUniformAll", 334) + val OpGroupNonUniformAny = Code("OpGroupNonUniformAny", 335) + val OpGroupNonUniformAllEqual = Code("OpGroupNonUniformAllEqual", 336) + val OpGroupNonUniformBroadcast = Code("OpGroupNonUniformBroadcast", 337) + val OpGroupNonUniformBroadcastFirst = Code("OpGroupNonUniformBroadcastFirst", 338) + val OpGroupNonUniformBallot = Code("OpGroupNonUniformBallot", 339) + val OpGroupNonUniformInverseBallot = Code("OpGroupNonUniformInverseBallot", 340) + val OpGroupNonUniformBallotBitExtract = Code("OpGroupNonUniformBallotBitExtract", 341) + val OpGroupNonUniformBallotBitCount = Code("OpGroupNonUniformBallotBitCount", 342) + val OpGroupNonUniformBallotFindLSB = Code("OpGroupNonUniformBallotFindLSB", 343) + val OpGroupNonUniformBallotFindMSB = Code("OpGroupNonUniformBallotFindMSB", 344) + val OpGroupNonUniformShuffle = Code("OpGroupNonUniformShuffle", 345) + val OpGroupNonUniformShuffleXor = Code("OpGroupNonUniformShuffleXor", 346) + val OpGroupNonUniformShuffleUp = Code("OpGroupNonUniformShuffleUp", 347) + val OpGroupNonUniformShuffleDown = Code("OpGroupNonUniformShuffleDown", 348) + val OpGroupNonUniformIAdd = Code("OpGroupNonUniformIAdd", 349) + val OpGroupNonUniformFAdd = Code("OpGroupNonUniformFAdd", 350) + val OpGroupNonUniformIMul = Code("OpGroupNonUniformIMul", 351) + val OpGroupNonUniformFMul = Code("OpGroupNonUniformFMul", 352) + val OpGroupNonUniformSMin = Code("OpGroupNonUniformSMin", 353) + val OpGroupNonUniformUMin = 
Code("OpGroupNonUniformUMin", 354) + val OpGroupNonUniformFMin = Code("OpGroupNonUniformFMin", 355) + val OpGroupNonUniformSMax = Code("OpGroupNonUniformSMax", 356) + val OpGroupNonUniformUMax = Code("OpGroupNonUniformUMax", 357) + val OpGroupNonUniformFMax = Code("OpGroupNonUniformFMax", 358) + val OpGroupNonUniformBitwiseAnd = Code("OpGroupNonUniformBitwiseAnd", 359) + val OpGroupNonUniformBitwiseOr = Code("OpGroupNonUniformBitwiseOr", 360) + val OpGroupNonUniformBitwiseXor = Code("OpGroupNonUniformBitwiseXor", 361) + val OpGroupNonUniformLogicalAnd = Code("OpGroupNonUniformLogicalAnd", 362) + val OpGroupNonUniformLogicalOr = Code("OpGroupNonUniformLogicalOr", 363) + val OpGroupNonUniformLogicalXor = Code("OpGroupNonUniformLogicalXor", 364) + val OpGroupNonUniformQuadBroadcast = Code("OpGroupNonUniformQuadBroadcast", 365) + val OpGroupNonUniformQuadSwap = Code("OpGroupNonUniformQuadSwap", 366) + object GlslOp: val Round = Code("Round", 1) val RoundEven = Code("RoundEven", 2) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala index ec3c4d0b..b6a98052 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvConstants.scala @@ -17,5 +17,12 @@ private[cyfra] object SpirvConstants: val GL_GLOBAL_INVOCATION_ID_REF = 5 val GL_WORKGROUP_SIZE_REF = 6 val DEBUG_PRINTF_REF = 7 + val GL_LOCAL_INVOCATION_ID_REF = 8 + val GL_LOCAL_INVOCATION_INDEX_REF = 9 + val GL_WORKGROUP_ID_REF = 10 + val GL_NUM_WORKGROUPS_REF = 11 + val GL_SUBGROUP_ID_REF = 12 + val GL_SUBGROUP_LOCAL_INVOCATION_ID_REF = 13 + val GL_SUBGROUP_SIZE_REF = 14 - val HEADER_REFS_TOP = 8 + val HEADER_REFS_TOP = 15 diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala index 7adeb972..c4fac1eb 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/SpirvTypes.scala @@ -10,6 +10,7 @@ private[cyfra] object SpirvTypes: val Int32Tag = summon[Tag[Int32]] val UInt32Tag = summon[Tag[UInt32]] + val Float16Tag = summon[Tag[Float16]] val Float32Tag = summon[Tag[Float32]] val GBooleanTag = summon[Tag[GBoolean]] val Vec2TagWithoutArgs = summon[Tag[Vec2[?]]].tag.withoutArgs @@ -22,6 +23,7 @@ private[cyfra] object SpirvTypes: val LInt32Tag = Int32Tag.tag val LUInt32Tag = UInt32Tag.tag + val LFloat16Tag = Float16Tag.tag val LFloat32Tag = Float32Tag.tag val LGBooleanTag = GBooleanTag.tag val LVec2TagWithoutArgs = Vec2TagWithoutArgs @@ -36,9 +38,38 @@ private[cyfra] object SpirvTypes: type Vec3C[T <: Value] = Vec3[T] type Vec4C[T <: Value] = Vec4[T] + /** Convert Float32 to Float16 (half precision) bits. + * Uses round-to-nearest-even rounding mode. 
+    */
+  def floatToFloat16(f: Float): Int = {
+    val bits = java.lang.Float.floatToIntBits(f)
+    val sign = (bits >>> 16) & 0x8000
+    val exponent = ((bits >>> 23) & 0xFF) - 127 + 15
+    val mantissa = bits & 0x007FFFFF
+
+    if (exponent <= 0) {
+      // Denormalized or zero
+      if (exponent < -10) {
+        sign // Zero
+      } else {
+        // Denormalized
+        val m = mantissa | 0x00800000
+        val shifted = m >>> (1 - exponent)
+        sign | (shifted >>> 13)
+      }
+    } else if (exponent >= 31) {
+      // Infinity or NaN: NaN only when the source was NaN; finite overflow maps to infinity
+      if (((bits >>> 23) & 0xFF) == 0xFF && mantissa != 0) sign | 0x7C00 | 0x200
+      else sign | 0x7C00
+    } else {
+      // Normalized: round to nearest even on the 13 dropped mantissa bits
+      val truncated = sign | (exponent << 10) | (mantissa >>> 13)
+      val rest = mantissa & 0x1FFF
+      if (rest > 0x1000 || (rest == 0x1000 && (truncated & 1) == 1)) truncated + 1
+      else truncated
+    }
+  }
+
   def scalarTypeDefInsn(tag: Tag[?], typeDefIndex: Int) = tag match
     case Int32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(1)))
     case UInt32Tag => Instruction(Op.OpTypeInt, List(ResultRef(typeDefIndex), IntWord(32), IntWord(0)))
+    case Float16Tag => Instruction(Op.OpTypeFloat, List(ResultRef(typeDefIndex), IntWord(16)))
     case Float32Tag => Instruction(Op.OpTypeFloat, List(ResultRef(typeDefIndex), IntWord(32)))
     case GBooleanTag => Instruction(Op.OpTypeBool, List(ResultRef(typeDefIndex)))
@@ -50,6 +81,7 @@ private[cyfra] object SpirvTypes:
   def typeStride(tag: LightTypeTag): Int = tag match
     case LInt32Tag => 4
     case LUInt32Tag => 4
+    case LFloat16Tag => 2
     case LFloat32Tag => 4
     case LGBooleanTag => 4
     case v if v <:< LVecTag =>
@@ -63,6 +95,14 @@
       IntWord(value.asInstanceOf[Int])
     case t if t == UInt32Tag =>
       IntWord(value.asInstanceOf[Int])
+    case t if t == Float16Tag =>
+      val fl = value match
+        case fl: Float => fl
+        case dl: Double => dl.toFloat
+        case il: Int => il.toFloat
+      // Convert Float32 to Float16 (half precision)
+      val f16Bits = floatToFloat16(fl)
+      Word(intToBytes(f16Bits & 0xFFFF).reverse.toArray)
     case t if t == Float32Tag =>
       val fl = value match
         case fl: Float => fl
        case dl: Double => dl.toFloat
@@ -71,7 +111,7 @@
       Word(intToBytes(java.lang.Float.floatToIntBits(fl)).reverse.toArray)
 
   def defineScalarTypes(types: List[Tag[?]], context: Context): (List[Words], Context) =
-    val basicTypes = List(Int32Tag, Float32Tag, UInt32Tag, GBooleanTag)
+    val basicTypes = List(Int32Tag, Float16Tag, Float32Tag, UInt32Tag, GBooleanTag)
     (basicTypes ::: types).distinct.foldLeft((List[Words](), context)) { case ((words, ctx), valType) =>
       val typeDefIndex = ctx.nextResultId
       val code = List(
diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala
index 8bdafb24..e8613bb0 100644
--- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala
+++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/DSLCompiler.scala
@@ -4,7 +4,7 @@ import io.computenode.cyfra.*
 import io.computenode.cyfra.dsl.*
 import io.computenode.cyfra.dsl.Expression.E
 import io.computenode.cyfra.dsl.Value.Scalar
-import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform, WriteBuffer, WriteUniform}
+import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GShared, GUniform, ReadShared, WriteBuffer, WriteShared, WriteUniform}
 import io.computenode.cyfra.dsl.gio.GIO
 import io.computenode.cyfra.dsl.struct.GStruct.*
 import io.computenode.cyfra.dsl.struct.GStructSchema
@@ -34,19 +34,36 @@ private[cyfra] object DSLCompiler:
       getAllExprsFlattened(tail, getAllExprsFlattened(v.tree, visitDetached) ::: acc, visitDetached)
     case GIO.FlatMap(v, n) :: tail =>
       getAllExprsFlattened(v :: n ::
tail, acc, visitDetached) - case GIO.Repeat(n, gio) :: tail => + case GIO.Repeat(n, gio, _) :: tail => val nAllExprs = getAllExprsFlattened(n.tree, visitDetached) getAllExprsFlattened(gio :: tail, nAllExprs ::: acc, visitDetached) - case WriteBuffer(_, index, value) :: tail => - val indexAllExprs = getAllExprsFlattened(index.tree, visitDetached) - val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached) - getAllExprsFlattened(tail, indexAllExprs ::: valueAllExprs ::: acc, visitDetached) + case (wb: WriteBuffer[?]) :: tail => + val indexAllExprs = getAllExprsFlattened(wb.index.tree, visitDetached) + val valueAllExprs = getAllExprsFlattened(wb.value.tree, visitDetached) + // Also collect the underlying Empty expression for FoldRepeat body lookup + val underlyingAllExprs = getAllExprsFlattened(wb.underlying.tree, visitDetached) + getAllExprsFlattened(tail, indexAllExprs ::: valueAllExprs ::: underlyingAllExprs ::: acc, visitDetached) case WriteUniform(_, value) :: tail => val valueAllExprs = getAllExprsFlattened(value.tree, visitDetached) getAllExprsFlattened(tail, valueAllExprs ::: acc, visitDetached) case GIO.Printf(_, args*) :: tail => val argsAllExprs = args.flatMap(a => getAllExprsFlattened(a.tree, visitDetached)).toList getAllExprsFlattened(tail, argsAllExprs ::: acc, visitDetached) + case GIO.WorkgroupBarrier :: tail => + getAllExprsFlattened(tail, acc, visitDetached) + case (ws: WriteShared[?]) :: tail => + val indexAllExprs = getAllExprsFlattened(ws.index.tree, visitDetached) + val valueAllExprs = getAllExprsFlattened(ws.value.tree, visitDetached) + // Also collect the underlying Empty expression for FoldRepeat body lookup + val underlyingAllExprs = getAllExprsFlattened(ws.underlying.tree, visitDetached) + getAllExprsFlattened(tail, indexAllExprs ::: valueAllExprs ::: underlyingAllExprs ::: acc, visitDetached) + case GIO.FoldRepeat(n, init, body, _, _) :: tail => + val nAllExprs = getAllExprsFlattened(n.tree, visitDetached) + val initAllExprs = getAllExprsFlattened(init.tree, visitDetached) + getAllExprsFlattened(body :: tail, nAllExprs ::: initAllExprs ::: acc, visitDetached) + case GIO.ConditionalWhen(cond, body) :: tail => + val condAllExprs = getAllExprsFlattened(cond.tree, visitDetached) + getAllExprsFlattened(body :: tail, condAllExprs ::: acc, visitDetached) // TODO: Not traverse same fn scopes for each fn call private def getAllExprsFlattened(root: E[?], visitDetached: Boolean): List[E[?]] = @@ -71,22 +88,88 @@ private[cyfra] object DSLCompiler: allScopesCache(root.treeid) = result result + private def getAllShared(pending: List[GIO[?]], acc: Map[Int, GShared[?]]): Map[Int, GShared[?]] = + pending match + case Nil => acc + case GIO.FlatMap(v, n) :: tail => + getAllShared(v :: n :: tail, acc) + case GIO.Repeat(_, gio, _) :: tail => + getAllShared(gio :: tail, acc) + case GIO.FoldRepeat(_, _, gio, _, _) :: tail => + getAllShared(gio :: tail, acc) + case GIO.ConditionalWhen(_, body) :: tail => + getAllShared(body :: tail, acc) + case WriteShared(buffer, _, _) :: tail => + val impl = buffer.asInstanceOf[GShared.GSharedImpl[?]] + getAllShared(tail, acc + (impl.sharedId -> buffer)) + case _ :: tail => getAllShared(tail, acc) + + private def getAllSharedFromExprs(exprs: List[E[?]], acc: Map[Int, GShared[?]]): Map[Int, GShared[?]] = + exprs.foldLeft(acc): + case (a, ReadShared(buffer, _)) => + val impl = buffer.asInstanceOf[GShared.GSharedImpl[?]] + a + (impl.sharedId -> buffer) + case (a, _) => a + + private def createSharedVariables(sharedBuffers: Map[Int, GShared[?]], 
ctx: Context): (List[Words], Context) = + sharedBuffers.foldLeft((List.empty[Words], ctx)): + case ((insnsAcc, currentCtx), (sharedId, buffer)) => + val elementTypeRef = currentCtx.valueTypeMap(buffer.tag.tag) + val arraySizeConstRef = currentCtx.constRefs.getOrElse( + (Int32Tag, buffer.size), + throw new IllegalStateException(s"Missing constant for shared array size ${buffer.size}"), + ) + + // SPIR-V shared memory structure: + // 1. Array type: OpTypeArray %arrayType %elementType %size + // 2. Pointer to array: OpTypePointer %ptrArrayType Workgroup %arrayType + // 3. Variable: OpVariable %ptrArrayType %var Workgroup + // 4. Pointer to element: OpTypePointer %ptrElemType Workgroup %elementType (for OpAccessChain) + val arrayTypeRef = currentCtx.nextResultId + val ptrArrayTypeRef = currentCtx.nextResultId + 1 + val varRef = currentCtx.nextResultId + 2 + val ptrElemTypeRef = currentCtx.nextResultId + 3 + + val insns = List( + Instruction(Op.OpTypeArray, List(ResultRef(arrayTypeRef), ResultRef(elementTypeRef), ResultRef(arraySizeConstRef))), + Instruction(Op.OpTypePointer, List(ResultRef(ptrArrayTypeRef), StorageClass.Workgroup, ResultRef(arrayTypeRef))), + Instruction(Op.OpVariable, List(ResultRef(ptrArrayTypeRef), ResultRef(varRef), StorageClass.Workgroup)), + Instruction(Op.OpTypePointer, List(ResultRef(ptrElemTypeRef), StorageClass.Workgroup, ResultRef(elementTypeRef))), + ) + + val block = SharedBlock(arrayTypeRef, varRef, ptrElemTypeRef) + val newCtx = currentCtx.copy( + nextResultId = currentCtx.nextResultId + 4, + sharedVarRefs = currentCtx.sharedVarRefs + (sharedId -> block), + workgroupPointerMap = currentCtx.workgroupPointerMap + (elementTypeRef -> ptrElemTypeRef), + ) + (insnsAcc ::: insns, newCtx) + // So far only used for printf private def getAllStrings(pending: List[GIO[?]], acc: Set[String]): Set[String] = pending match case Nil => acc case GIO.FlatMap(v, n) :: tail => getAllStrings(v :: n :: tail, acc) - case GIO.Repeat(_, gio) :: tail => + case GIO.Repeat(_, gio, _) :: tail => getAllStrings(gio :: tail, acc) + case GIO.FoldRepeat(_, _, gio, _, _) :: tail => + getAllStrings(gio :: tail, acc) + case GIO.ConditionalWhen(_, body) :: tail => + getAllStrings(body :: tail, acc) case GIO.Printf(format, _*) :: tail => getAllStrings(tail, acc + format) case _ :: tail => getAllStrings(tail, acc) - def compile(bodyIo: GIO[?], bindings: List[GBinding[?]]): ByteBuffer = + def compile(bodyIo: GIO[?], bindings: List[GBinding[?]], workgroupSize: (Int, Int, Int) = (256, 1, 1)): ByteBuffer = val allExprs = getAllExprsFlattened(List(bodyIo), Nil, visitDetached = true) val typesInCode = allExprs.map(_.tag).distinct - val allTypes = (typesInCode ::: bindings.map(_.tag)).distinct + + val sharedFromGio = getAllShared(List(bodyIo), Map.empty) + val sharedFromExprs = getAllSharedFromExprs(allExprs, sharedFromGio) + val sharedTypes = sharedFromExprs.values.map(_.tag).toList + + val allTypes = (typesInCode ::: bindings.map(_.tag) ::: sharedTypes).distinct def scalarTypes = allTypes.filter(_.tag <:< summon[Tag[Scalar]].tag) val (typeDefs, typedContext) = defineScalarTypes(scalarTypes, Context.initialContext) val allStrings = getAllStrings(List(bodyIo), Set.empty) @@ -108,17 +191,28 @@ private[cyfra] object DSLCompiler: val (decorations, uniformDefs, uniformContext) = initAndDecorateBuffers(buffersWithIndices, structNamesCtx) val (uniformStructDecorations, uniformStructInsns, uniformStructContext) = createAndInitUniformBlocks(uniformsWithIndices, uniformContext) val blockNames = 
getBlockNames(uniformContext, uniforms) - val (inputDefs, inputContext) = createInvocationId(uniformStructContext) + val (inputDefs, inputContext) = createInvocationId(uniformStructContext, workgroupSize) + + val sharedSizeConsts = sharedFromExprs.values.map(s => (Int32Tag, s.size)).toList val (constDefs, constCtx) = defineConstants(allExprs, inputContext) - val (varDefs, varCtx) = defineVarNames(constCtx) + val (sharedConstDefs, constCtxWithShared) = sharedSizeConsts.foldLeft((List.empty[Words], constCtx)): + case ((insnsAcc, ctx), const) if ctx.constRefs.contains(const) => (insnsAcc, ctx) + case ((insnsAcc, ctx), const) => + val insn = Instruction(Op.OpConstant, List(ResultRef(ctx.valueTypeMap(const._1.tag)), ResultRef(ctx.nextResultId), IntWord(const._2))) + val newCtx = ctx.copy(constRefs = ctx.constRefs + (const -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (insnsAcc :+ insn, newCtx) + + val (sharedDefs, ctxWithShared) = createSharedVariables(sharedFromExprs, constCtxWithShared) + + val (varDefs, varCtx) = defineVarNames(ctxWithShared) val (main, ctxAfterMain) = compileMain(bodyIo, varCtx) val (fnTypeDefs, fnDefs, ctxWithFnDefs) = compileFunctions(ctxAfterMain) val nameDecorations = getNameDecorations(ctxWithFnDefs) val code: List[Words] = - SpirvProgramCompiler.headers ::: stringDefs ::: blockNames ::: nameDecorations ::: structNames ::: SpirvProgramCompiler.workgroupDecorations ::: + SpirvProgramCompiler.headers(workgroupSize) ::: stringDefs ::: blockNames ::: nameDecorations ::: structNames ::: SpirvProgramCompiler.workgroupDecorations ::: decorations ::: uniformStructDecorations ::: typeDefs ::: structDefs ::: fnTypeDefs ::: uniformDefs ::: uniformStructInsns ::: inputDefs ::: - constDefs ::: varDefs ::: main ::: fnDefs + constDefs ::: sharedConstDefs ::: sharedDefs ::: varDefs ::: main ::: fnDefs val fullCode = code.map: case WordVariable(name) if name == BOUND_VARIABLE => IntWord(ctxWithFnDefs.nextResultId) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala index 6e859bd3..b34d2292 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/ExpressionCompiler.scala @@ -29,6 +29,33 @@ private[cyfra] object ExpressionCompiler: case _: Div[?] => (Op.OpSDiv, Op.OpFDiv) case _: Mod[?] 
=> (Op.OpSMod, Op.OpFMod) + private def compileSubgroupOp( + expr: E[?], + value: Value.Scalar, + op: SubgroupOp, + spirvOp: Code, + ctx: Context, + ): (List[Instruction], Context) = + val scopeId = ctx.constRefs((Int32Tag, Scope.Subgroup.opcode)) + val groupOpCode = op match + case SubgroupOp.Reduce => GroupOperation.Reduce + case SubgroupOp.InclusiveScan => GroupOperation.InclusiveScan + case SubgroupOp.ExclusiveScan => GroupOperation.ExclusiveScan + val instructions = List( + Instruction( + spirvOp, + List( + ResultRef(ctx.valueTypeMap(expr.tag.tag)), + ResultRef(ctx.nextResultId), + ResultRef(scopeId), + groupOpCode, + ResultRef(ctx.exprRefs(value.treeid)), + ), + ), + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (instructions, updatedContext) + private def compileBinaryOpExpression(bexpr: BinaryOpExpression[?], ctx: Context): (List[Instruction], Context) = val tpe = bexpr.tag val typeRef = ctx.valueTypeMap(tpe.tag) @@ -52,12 +79,18 @@ private[cyfra] object ExpressionCompiler: val tpe = cexpr.tag val typeRef = ctx.valueTypeMap(tpe.tag) val tfOpcode = (cexpr.fromTag, cexpr) match - case (from, _: ToFloat32[?]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF - case (from, _: ToFloat32[?]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF - case (from, _: ToInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS - case (from, _: ToUInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU - case (from, _: ToInt32[?]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast - case (from, _: ToUInt32[?]) if from.tag =:= Int32Tag.tag => Op.OpBitcast + case (from, _: ToFloat16[?]) if from.tag =:= Float32Tag.tag => Op.OpFConvert + case (from, _: ToFloat16[?]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF + case (from, _: ToFloat16[?]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF + case (from, _: ToFloat32[?]) if from.tag =:= Float16Tag.tag => Op.OpFConvert + case (from, _: ToFloat32[?]) if from.tag =:= Int32Tag.tag => Op.OpConvertSToF + case (from, _: ToFloat32[?]) if from.tag =:= UInt32Tag.tag => Op.OpConvertUToF + case (from, _: ToInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToS + case (from, _: ToInt32[?]) if from.tag =:= Float16Tag.tag => Op.OpConvertFToS + case (from, _: ToUInt32[?]) if from.tag =:= Float32Tag.tag => Op.OpConvertFToU + case (from, _: ToUInt32[?]) if from.tag =:= Float16Tag.tag => Op.OpConvertFToU + case (from, _: ToInt32[?]) if from.tag =:= UInt32Tag.tag => Op.OpBitcast + case (from, _: ToUInt32[?]) if from.tag =:= Int32Tag.tag => Op.OpBitcast val instructions = List(Instruction(tfOpcode, List(ResultRef(typeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(cexpr.a.treeid))))) val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cexpr.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) (instructions, updatedContext) @@ -109,12 +142,136 @@ private[cyfra] object ExpressionCompiler: case w @ InvocationId => (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.workerIndexRef))) + case w @ LocalInvocationIndex => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.localInvocationIndexRef))) + + case w @ LocalInvocationId => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.localInvocationIdRef))) + + case w @ WorkgroupId => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.workgroupIdRef))) + + case w @ NumWorkgroups => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.numWorkgroupsRef))) + + case w @ 
SubgroupId => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.subgroupIdRef))) + + case w @ SubgroupLocalInvocationId => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.subgroupLocalInvocationIdRef))) + + case w @ SubgroupSize => + (Nil, ctx.copy(exprRefs = ctx.exprRefs + (w.treeid -> ctx.subgroupSizeRef))) + + case sg @ SubgroupAddI(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformIAdd, ctx) + + case sg @ SubgroupAddF(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformFAdd, ctx) + + case sg @ SubgroupAddF16(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformFAdd, ctx) + + case sg @ SubgroupMinI(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformSMin, ctx) + + case sg @ SubgroupMinF(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformFMin, ctx) + + case sg @ SubgroupMinF16(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformFMin, ctx) + + case sg @ SubgroupMaxI(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformSMax, ctx) + + case sg @ SubgroupMaxF(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformFMax, ctx) + + case sg @ SubgroupMaxF16(v, op) => + compileSubgroupOp(sg, v, op, Op.OpGroupNonUniformFMax, ctx) + + case sg @ SubgroupBroadcast(v, lane) => + val scopeId = ctx.constRefs((Int32Tag, Scope.Subgroup.opcode)) + val instructions = List( + Instruction( + Op.OpGroupNonUniformBroadcast, + List( + ResultRef(ctx.valueTypeMap(sg.tag.tag)), + ResultRef(ctx.nextResultId), + ResultRef(scopeId), + ResultRef(ctx.exprRefs(v.treeid)), + ResultRef(ctx.exprRefs(lane.treeid)), + ), + ), + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (sg.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (instructions, updatedContext) + + case sg @ SubgroupBroadcastFirst(v) => + val scopeId = ctx.constRefs((Int32Tag, Scope.Subgroup.opcode)) + val instructions = List( + Instruction( + Op.OpGroupNonUniformBroadcastFirst, + List( + ResultRef(ctx.valueTypeMap(sg.tag.tag)), + ResultRef(ctx.nextResultId), + ResultRef(scopeId), + ResultRef(ctx.exprRefs(v.treeid)), + ), + ), + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (sg.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (instructions, updatedContext) + + case sg @ SubgroupShuffle(v, lane) => + val scopeId = ctx.constRefs((Int32Tag, Scope.Subgroup.opcode)) + val instructions = List( + Instruction( + Op.OpGroupNonUniformShuffle, + List( + ResultRef(ctx.valueTypeMap(sg.tag.tag)), + ResultRef(ctx.nextResultId), + ResultRef(scopeId), + ResultRef(ctx.exprRefs(v.treeid)), + ResultRef(ctx.exprRefs(lane.treeid)), + ), + ), + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (sg.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (instructions, updatedContext) + + case sg @ SubgroupShuffleXor(v, mask) => + val scopeId = ctx.constRefs((Int32Tag, Scope.Subgroup.opcode)) + val instructions = List( + Instruction( + Op.OpGroupNonUniformShuffleXor, + List( + ResultRef(ctx.valueTypeMap(sg.tag.tag)), + ResultRef(ctx.nextResultId), + ResultRef(scopeId), + ResultRef(ctx.exprRefs(v.treeid)), + ResultRef(ctx.exprRefs(mask.treeid)), + ), + ), + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (sg.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (instructions, updatedContext) + case d @ ReadUniform(u) => (Nil, ctx.copy(exprRefs = ctx.exprRefs + (d.treeid -> ctx.uniformVarRefs(u)))) case c: ConvertExpression[?, ?] 
=> compileConvertExpression(c, ctx) + case cvf @ ConvertVec4F16ToF32(v) => + // Convert Vec4[Float16] to Vec4[Float32] using OpFConvert + val vec4F32TypeRef = ctx.valueTypeMap(cvf.tag.tag) + val instructions = List( + Instruction(Op.OpFConvert, List(ResultRef(vec4F32TypeRef), ResultRef(ctx.nextResultId), ResultRef(ctx.exprRefs(v.treeid)))) + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (cvf.treeid -> ctx.nextResultId), nextResultId = ctx.nextResultId + 1) + (instructions, updatedContext) + case b: BinaryOpExpression[?] => compileBinaryOpExpression(b, ctx) @@ -306,6 +463,76 @@ private[cyfra] object ExpressionCompiler: val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) (instructions, updatedContext) + case rb: ReadBufferVec4[?] => + // Read 4 consecutive elements and construct a Vec4 + // This generates: 4x (OpIAdd + OpAccessChain + OpLoad) + OpCompositeConstruct + val buffer = rb.buffer + val baseIdx = rb.index + val elemTypeRef = ctx.valueTypeMap(buffer.tag.tag) + val int32TypeRef = ctx.valueTypeMap(summon[Tag[Int32]].tag) + val vec4TypeRef = ctx.valueTypeMap(rb.tag.tag) + val ptrTypeRef = ctx.uniformPointerMap(elemTypeRef) + val blockVarRef = ctx.bufferBlocks(buffer).blockVarRef + val const0Ref = ctx.constRefs((Int32Tag, 0)) + val const1Ref = ctx.constRefs((Int32Tag, 1)) + val const2Ref = ctx.constRefs((Int32Tag, 2)) + val const3Ref = ctx.constRefs((Int32Tag, 3)) + val baseIdxRef = ctx.exprRefs(baseIdx.treeid) + + var rid = ctx.nextResultId + val idx0 = baseIdxRef // baseIdx + 0 + val idx1Ref = rid; rid += 1 // baseIdx + 1 + val idx2Ref = rid; rid += 1 // baseIdx + 2 + val idx3Ref = rid; rid += 1 // baseIdx + 3 + val ptr0Ref = rid; rid += 1 + val ptr1Ref = rid; rid += 1 + val ptr2Ref = rid; rid += 1 + val ptr3Ref = rid; rid += 1 + val val0Ref = rid; rid += 1 + val val1Ref = rid; rid += 1 + val val2Ref = rid; rid += 1 + val val3Ref = rid; rid += 1 + val resultRef = rid; rid += 1 + + val instructions = List( + // Compute indices: baseIdx+1, baseIdx+2, baseIdx+3 + Instruction(Op.OpIAdd, List(ResultRef(int32TypeRef), ResultRef(idx1Ref), ResultRef(baseIdxRef), ResultRef(const1Ref))), + Instruction(Op.OpIAdd, List(ResultRef(int32TypeRef), ResultRef(idx2Ref), ResultRef(baseIdxRef), ResultRef(const2Ref))), + Instruction(Op.OpIAdd, List(ResultRef(int32TypeRef), ResultRef(idx3Ref), ResultRef(baseIdxRef), ResultRef(const3Ref))), + // Access chain for each element + Instruction(Op.OpAccessChain, List(ResultRef(ptrTypeRef), ResultRef(ptr0Ref), ResultRef(blockVarRef), ResultRef(const0Ref), ResultRef(idx0))), + Instruction(Op.OpAccessChain, List(ResultRef(ptrTypeRef), ResultRef(ptr1Ref), ResultRef(blockVarRef), ResultRef(const0Ref), ResultRef(idx1Ref))), + Instruction(Op.OpAccessChain, List(ResultRef(ptrTypeRef), ResultRef(ptr2Ref), ResultRef(blockVarRef), ResultRef(const0Ref), ResultRef(idx2Ref))), + Instruction(Op.OpAccessChain, List(ResultRef(ptrTypeRef), ResultRef(ptr3Ref), ResultRef(blockVarRef), ResultRef(const0Ref), ResultRef(idx3Ref))), + // Load each element + Instruction(Op.OpLoad, List(ResultRef(elemTypeRef), ResultRef(val0Ref), ResultRef(ptr0Ref))), + Instruction(Op.OpLoad, List(ResultRef(elemTypeRef), ResultRef(val1Ref), ResultRef(ptr1Ref))), + Instruction(Op.OpLoad, List(ResultRef(elemTypeRef), ResultRef(val2Ref), ResultRef(ptr2Ref))), + Instruction(Op.OpLoad, List(ResultRef(elemTypeRef), ResultRef(val3Ref), ResultRef(ptr3Ref))), + // Construct Vec4 + Instruction(Op.OpCompositeConstruct, 
List(ResultRef(vec4TypeRef), ResultRef(resultRef), ResultRef(val0Ref), ResultRef(val1Ref), ResultRef(val2Ref), ResultRef(val3Ref))), + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> resultRef), nextResultId = rid) + (instructions, updatedContext) + + case ReadShared(buffer, i) => + val sharedId = buffer.asInstanceOf[GShared.GSharedImpl[?]].sharedId + val sharedBlock = ctx.sharedVarRefs(sharedId) + val instructions = List( + Instruction( + Op.OpAccessChain, + List( + ResultRef(sharedBlock.pointerTypeRef), + ResultRef(ctx.nextResultId), + ResultRef(sharedBlock.varRef), + ResultRef(ctx.exprRefs(i.treeid)), + ), + ), + Instruction(Op.OpLoad, List(IntWord(ctx.valueTypeMap(buffer.tag.tag)), ResultRef(ctx.nextResultId + 1), ResultRef(ctx.nextResultId))), + ) + val updatedContext = ctx.copy(exprRefs = ctx.exprRefs + (expr.treeid -> (ctx.nextResultId + 1)), nextResultId = ctx.nextResultId + 2) + (instructions, updatedContext) + case when: WhenExpr[?] => compileWhen(when, ctx) diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala index 11adc24c..66ac0b82 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GIOCompiler.scala @@ -1,112 +1,91 @@ package io.computenode.cyfra.spirv.compilers +import io.computenode.cyfra.dsl.Expression.E import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.gio.GIO.{ConditionalWhen, CurrentFoldRepeatAcc, CurrentRepeatIndex, FlatMap, FoldRepeat, Printf, Pure, Repeat, WorkgroupBarrier} +import io.computenode.cyfra.dsl.binding.{WriteBuffer, WriteShared} import io.computenode.cyfra.spirv.Context import io.computenode.cyfra.spirv.Opcodes.* -import io.computenode.cyfra.dsl.binding.* -import io.computenode.cyfra.dsl.gio.GIO.CurrentRepeatIndex import io.computenode.cyfra.spirv.SpirvConstants.{DEBUG_PRINTF_REF, TYPE_VOID_REF} -import io.computenode.cyfra.spirv.SpirvTypes.{GBooleanTag, Int32Tag, LInt32Tag} +import io.computenode.cyfra.spirv.SpirvTypes.{GBooleanTag, Int32Tag} + +import scala.collection.mutable object GIOCompiler: def compileGio(gio: GIO[?], ctx: Context, acc: List[Words] = Nil): (List[Words], Context) = gio match - case GIO.Pure(v) => + case Pure(v) => val (insts, updatedCtx) = ExpressionCompiler.compileBlock(v.tree, ctx) (acc ::: insts, updatedCtx) - case WriteBuffer(buffer, index, value) => + case wb @ WriteBuffer(buffer, index, value) => val (valueInsts, ctxWithValue) = ExpressionCompiler.compileBlock(value.tree, ctx) val (indexInsts, ctxWithIndex) = ExpressionCompiler.compileBlock(index.tree, ctxWithValue) - + // Compile the underlying Empty to register it for FoldRepeat body lookup + val (underlyingInsts, ctxWithUnderlying) = ExpressionCompiler.compileBlock(wb.underlying.tree, ctxWithIndex) val insns = List( Instruction( Op.OpAccessChain, List( - ResultRef(ctxWithIndex.uniformPointerMap(ctxWithIndex.valueTypeMap(buffer.tag.tag))), - ResultRef(ctxWithIndex.nextResultId), - ResultRef(ctxWithIndex.bufferBlocks(buffer).blockVarRef), - ResultRef(ctxWithIndex.constRefs((Int32Tag, 0))), - ResultRef(ctxWithIndex.exprRefs(index.tree.treeid)), + ResultRef(ctxWithUnderlying.uniformPointerMap(ctxWithUnderlying.valueTypeMap(buffer.tag.tag))), + ResultRef(ctxWithUnderlying.nextResultId), + ResultRef(ctxWithUnderlying.bufferBlocks(buffer).blockVarRef), + 
ResultRef(ctxWithUnderlying.constRefs((Int32Tag, 0))), + ResultRef(ctxWithUnderlying.exprRefs(index.tree.treeid)), ), ), - Instruction(Op.OpStore, List(ResultRef(ctxWithIndex.nextResultId), ResultRef(ctxWithIndex.exprRefs(value.tree.treeid)))), + Instruction(Op.OpStore, List(ResultRef(ctxWithUnderlying.nextResultId), ResultRef(ctxWithUnderlying.exprRefs(value.tree.treeid)))), ) - val updatedCtx = ctxWithIndex.copy(nextResultId = ctxWithIndex.nextResultId + 1) - (acc ::: indexInsts ::: valueInsts ::: insns, updatedCtx) + val updatedCtx = ctxWithUnderlying.copy(nextResultId = ctxWithUnderlying.nextResultId + 1) + // valueInsts before indexInsts: value compiled first, may define exprs index uses + (acc ::: valueInsts ::: indexInsts ::: underlyingInsts ::: insns, updatedCtx) - case GIO.FlatMap(v, n) => + case FlatMap(v, n) => val (vInsts, ctxAfterV) = compileGio(v, ctx, acc) compileGio(n, ctxAfterV, vInsts) - case GIO.Repeat(n, f) => - // Compile 'n' first (so we can use its id in the comparison) - val (nInsts, ctxWithN) = ExpressionCompiler.compileBlock(n.tree, ctx) - - // Types and constants - val intTy = ctxWithN.valueTypeMap(Int32Tag.tag) - val boolTy = ctxWithN.valueTypeMap(GBooleanTag.tag) - val zeroId = ctxWithN.constRefs((Int32Tag, 0)) - val oneId = ctxWithN.constRefs((Int32Tag, 1)) - val nId = ctxWithN.exprRefs(n.tree.treeid) - - // Reserve ids for blocks and results - val baseId = ctxWithN.nextResultId - val preHeaderId = baseId - val headerId = baseId + 1 - val bodyId = baseId + 2 - val continueId = baseId + 3 - val mergeId = baseId + 4 - val phiId = baseId + 5 - val cmpId = baseId + 6 - val addId = baseId + 7 - - // Bind CurrentRepeatIndex to the phi result for body compilation - val bodyCtx = ctxWithN.copy(nextResultId = baseId + 8, exprRefs = ctxWithN.exprRefs + (CurrentRepeatIndex.treeid -> phiId)) - val (bodyInsts, ctxAfterBody) = compileGio(f, bodyCtx) // ← Capture the context after body compilation - - // Preheader: close current block and jump to header through a dedicated block - val preheader = List( - Instruction(Op.OpBranch, List(ResultRef(preHeaderId))), - Instruction(Op.OpLabel, List(ResultRef(preHeaderId))), - Instruction(Op.OpBranch, List(ResultRef(headerId))), - ) + case r @ Repeat(n, f, unroll) => + compileRepeat(n, f, unroll, ctx, acc) - // Header: OpPhi first, then compute condition, then OpLoopMerge and the terminating branch - val header = List( - Instruction(Op.OpLabel, List(ResultRef(headerId))), - // OpPhi must be first in the block - Instruction( - Op.OpPhi, - List(ResultRef(intTy), ResultRef(phiId), ResultRef(zeroId), ResultRef(preHeaderId), ResultRef(addId), ResultRef(continueId)), - ), - // cmp = (counter < n) - Instruction(Op.OpSLessThan, List(ResultRef(boolTy), ResultRef(cmpId), ResultRef(phiId), ResultRef(nId))), - // OpLoopMerge must be the second-to-last instruction, before the terminating branch - Instruction(Op.OpLoopMerge, List(ResultRef(mergeId), ResultRef(continueId), LoopControlMask.MaskNone)), - Instruction(Op.OpBranchConditional, List(ResultRef(cmpId), ResultRef(bodyId), ResultRef(mergeId))), - ) - - val bodyBlk = List(Instruction(Op.OpLabel, List(ResultRef(bodyId)))) ::: bodyInsts ::: List(Instruction(Op.OpBranch, List(ResultRef(continueId)))) + case fr: FoldRepeat[?] 
=> + compileFoldRepeat(fr, ctx, acc) - val contBlk = List( - Instruction(Op.OpLabel, List(ResultRef(continueId))), - Instruction(Op.OpIAdd, List(ResultRef(intTy), ResultRef(addId), ResultRef(phiId), ResultRef(oneId))), - Instruction(Op.OpBranch, List(ResultRef(headerId))), + case WorkgroupBarrier => + val scopeId = ctx.constRefs((Int32Tag, Scope.Workgroup.opcode)) + val semanticsId = ctx.constRefs((Int32Tag, MemorySemantics.WorkgroupMemory.opcode | MemorySemantics.AcquireRelease.opcode)) + val barrierInsn = Instruction( + Op.OpControlBarrier, + List(ResultRef(scopeId), ResultRef(scopeId), ResultRef(semanticsId)), ) + (acc ::: List(barrierInsn), ctx) - val mergeBlk = List(Instruction(Op.OpLabel, List(ResultRef(mergeId)))) + case ConditionalWhen(cond, body) => + compileConditionalWhen(cond, body, ctx, acc) - // Use the highest nextResultId to avoid ID collisions - val finalNextId = math.max(ctxAfterBody.nextResultId, addId + 1) // ← Use ctxAfterBody.nextResultId - // Use ctxWithN as base to prevent loop-local values from being referenced outside - val finalCtx = ctxWithN.copy(nextResultId = finalNextId) - - (acc ::: nInsts ::: preheader ::: header ::: bodyBlk ::: contBlk ::: mergeBlk, finalCtx) + case WriteShared(buffer, index, value) => + val sharedId = buffer.asInstanceOf[io.computenode.cyfra.dsl.binding.GShared.GSharedImpl[?]].sharedId + val (valueInsts, ctxWithValue) = ExpressionCompiler.compileBlock(value.tree, ctx) + val (indexInsts, ctxWithIndex) = ExpressionCompiler.compileBlock(index.tree, ctxWithValue) + val sharedBlock = ctxWithIndex.sharedVarRefs(sharedId) + val insns = List( + Instruction( + Op.OpAccessChain, + List( + ResultRef(sharedBlock.pointerTypeRef), + ResultRef(ctxWithIndex.nextResultId), + ResultRef(sharedBlock.varRef), + ResultRef(ctxWithIndex.exprRefs(index.tree.treeid)), + ), + ), + Instruction(Op.OpStore, List(ResultRef(ctxWithIndex.nextResultId), ResultRef(ctxWithIndex.exprRefs(value.tree.treeid)))), + ) + val updatedCtx = ctxWithIndex.copy(nextResultId = ctxWithIndex.nextResultId + 1) + // valueInsts before indexInsts: value compiled first, may define exprs index uses + (acc ::: valueInsts ::: indexInsts ::: insns, updatedCtx) - case GIO.Printf(format, args*) => + case Printf(format, args*) => val (argsInsts, ctxAfterArgs) = args.foldLeft((List.empty[Words], ctx)) { case ((instsAcc, cAcc), arg) => val (argInsts, cAfterArg) = ExpressionCompiler.compileBlock(arg.tree, cAcc) (instsAcc ::: argInsts, cAfterArg) @@ -123,3 +102,309 @@ object GIOCompiler: ) ::: argResults, ) (acc ::: argsInsts ::: List(printf), ctxAfterArgs.copy(nextResultId = ctxAfterArgs.nextResultId + 1)) + + private def compileRepeat( + n: io.computenode.cyfra.dsl.Value.Int32, + f: GIO[?], + unroll: Boolean, + ctx: Context, + acc: List[Words], + ): (List[Words], Context) = + val (nInsts, ctxWithN) = ExpressionCompiler.compileBlock(n.tree, ctx) + + // Hoist loop-invariant expressions before the loop + // Only hoist expressions that don't depend on loop variables or control flow + val bodyExprs = collectExpressionsMap(f) + val loopDependent = findLoopDependentExprs(bodyExprs, CurrentRepeatIndex.treeid) + val scopeDependent = findScopeDependentExprs(bodyExprs) + // Filter out loop-dependent and scope-dependent (When, etc.) 
expressions + val invariantExprs = bodyExprs.values.filter { e => + !loopDependent.contains(e.treeid) && !scopeDependent.contains(e.treeid) + }.toList.sortBy(_.treeid) // Sort by treeid to respect definition order + val (invariantInsts, ctxWithInvariants) = invariantExprs.foldLeft((List.empty[Words], ctxWithN)) { + case ((instsAcc, ctxAcc), expr) => + val (insts, newCtx) = ExpressionCompiler.compileBlock(expr, ctxAcc) + (instsAcc ::: insts, newCtx) + } + + val intTy = ctxWithInvariants.valueTypeMap(Int32Tag.tag) + val boolTy = ctxWithInvariants.valueTypeMap(GBooleanTag.tag) + val zeroId = ctxWithInvariants.constRefs((Int32Tag, 0)) + val oneId = ctxWithInvariants.constRefs((Int32Tag, 1)) + val nId = ctxWithInvariants.exprRefs(n.tree.treeid) + + val baseId = ctxWithInvariants.nextResultId + val preHeaderId = baseId + val headerId = baseId + 1 + val bodyId = baseId + 2 + val continueId = baseId + 3 + val mergeId = baseId + 4 + val phiId = baseId + 5 + val cmpId = baseId + 6 + val addId = baseId + 7 + + val bodyCtx = ctxWithInvariants.copy( + nextResultId = baseId + 8, + exprRefs = ctxWithInvariants.exprRefs + (CurrentRepeatIndex.treeid -> phiId), + ) + val (bodyInsts, ctxAfterBody) = compileGio(f, bodyCtx) + + val preheader = List( + Instruction(Op.OpBranch, List(ResultRef(preHeaderId))), + Instruction(Op.OpLabel, List(ResultRef(preHeaderId))), + Instruction(Op.OpBranch, List(ResultRef(headerId))), + ) + + val header = List( + Instruction(Op.OpLabel, List(ResultRef(headerId))), + Instruction( + Op.OpPhi, + List(ResultRef(intTy), ResultRef(phiId), ResultRef(zeroId), ResultRef(preHeaderId), ResultRef(addId), ResultRef(continueId)), + ), + Instruction(Op.OpSLessThan, List(ResultRef(boolTy), ResultRef(cmpId), ResultRef(phiId), ResultRef(nId))), + Instruction(Op.OpLoopMerge, List(ResultRef(mergeId), ResultRef(continueId), + if unroll then LoopControlMask.Unroll else LoopControlMask.MaskNone)), + Instruction(Op.OpBranchConditional, List(ResultRef(cmpId), ResultRef(bodyId), ResultRef(mergeId))), + ) + + val bodyBlk = + List(Instruction(Op.OpLabel, List(ResultRef(bodyId)))) ::: + bodyInsts ::: + List(Instruction(Op.OpBranch, List(ResultRef(continueId)))) + + val contBlk = List( + Instruction(Op.OpLabel, List(ResultRef(continueId))), + Instruction(Op.OpIAdd, List(ResultRef(intTy), ResultRef(addId), ResultRef(phiId), ResultRef(oneId))), + Instruction(Op.OpBranch, List(ResultRef(headerId))), + ) + + val mergeBlk = List(Instruction(Op.OpLabel, List(ResultRef(mergeId)))) + + val finalNextId = math.max(ctxAfterBody.nextResultId, addId + 1) + val finalCtx = ctxWithInvariants.copy(nextResultId = finalNextId) + + (acc ::: nInsts ::: invariantInsts ::: preheader ::: header ::: bodyBlk ::: contBlk ::: mergeBlk, finalCtx) + + /** Compiles foldRepeat - a loop with an accumulator that can contain barriers. 
*/ + private def compileFoldRepeat( + fr: FoldRepeat[?], + ctx: Context, + acc: List[Words], + ): (List[Words], Context) = + val n = fr.n + val init = fr.init + val body = fr.body + val accTreeId = fr.accTreeId + + // Compile n and init + val (nInsts, ctxWithN) = ExpressionCompiler.compileBlock(n.tree, ctx) + val (initInsts, ctxWithInit) = ExpressionCompiler.compileBlock(init.tree, ctxWithN) + + // Hoist loop-invariant expressions before the loop + // Only hoist expressions that don't depend on loop variables + val bodyExprs = collectExpressionsMap(body) + val loopDependent = findLoopDependentExprs(bodyExprs, CurrentRepeatIndex.treeid) + accTreeId + val scopeDependent = findScopeDependentExprs(bodyExprs) + // Filter out loop-dependent and scope-dependent (When, etc.) expressions + val invariantExprs = bodyExprs.values.filter { e => + !loopDependent.contains(e.treeid) && !scopeDependent.contains(e.treeid) + }.toList.sortBy(_.treeid) // Sort by treeid to respect definition order + val (invariantInsts, ctxWithInvariants) = invariantExprs.foldLeft((List.empty[Words], ctxWithInit)) { + case ((instsAcc, ctxAcc), expr) => + val (insts, newCtx) = ExpressionCompiler.compileBlock(expr, ctxAcc) + (instsAcc ::: insts, newCtx) + } + + val intTy = ctxWithInvariants.valueTypeMap(Int32Tag.tag) + val accTy = ctxWithInvariants.valueTypeMap(init.tree.tag.tag) + val boolTy = ctxWithInvariants.valueTypeMap(GBooleanTag.tag) + val zeroId = ctxWithInvariants.constRefs((Int32Tag, 0)) + val oneId = ctxWithInvariants.constRefs((Int32Tag, 1)) + val nId = ctxWithInvariants.exprRefs(n.tree.treeid) + val initId = ctxWithInvariants.exprRefs(init.tree.treeid) + + val baseId = ctxWithInvariants.nextResultId + val preHeaderId = baseId + val headerId = baseId + 1 + val bodyId = baseId + 2 + val continueId = baseId + 3 + val mergeId = baseId + 4 + val iterPhiId = baseId + 5 // loop counter phi + val accPhiId = baseId + 6 // accumulator phi + val cmpId = baseId + 7 + val addId = baseId + 8 + + // Setup context for body compilation with both loop counter and accumulator + val bodyCtx = ctxWithInvariants.copy( + nextResultId = baseId + 9, + exprRefs = ctxWithInvariants.exprRefs + + (CurrentRepeatIndex.treeid -> iterPhiId) + + (accTreeId -> accPhiId), + ) + + val (bodyInsts, ctxAfterBody) = compileGio(body, bodyCtx) + val bodyResultId = ctxAfterBody.exprRefs(body.underlying.tree.treeid) + + val preheader = List( + Instruction(Op.OpBranch, List(ResultRef(preHeaderId))), + Instruction(Op.OpLabel, List(ResultRef(preHeaderId))), + Instruction(Op.OpBranch, List(ResultRef(headerId))), + ) + + val header = List( + Instruction(Op.OpLabel, List(ResultRef(headerId))), + // Phi for loop counter + Instruction( + Op.OpPhi, + List(ResultRef(intTy), ResultRef(iterPhiId), ResultRef(zeroId), ResultRef(preHeaderId), ResultRef(addId), ResultRef(continueId)), + ), + // Phi for accumulator + Instruction( + Op.OpPhi, + List(ResultRef(accTy), ResultRef(accPhiId), ResultRef(initId), ResultRef(preHeaderId), ResultRef(bodyResultId), ResultRef(continueId)), + ), + Instruction(Op.OpSLessThan, List(ResultRef(boolTy), ResultRef(cmpId), ResultRef(iterPhiId), ResultRef(nId))), + Instruction(Op.OpLoopMerge, List(ResultRef(mergeId), ResultRef(continueId), + if fr.unroll then LoopControlMask.Unroll else LoopControlMask.MaskNone)), + Instruction(Op.OpBranchConditional, List(ResultRef(cmpId), ResultRef(bodyId), ResultRef(mergeId))), + ) + + val bodyBlk = + List(Instruction(Op.OpLabel, List(ResultRef(bodyId)))) ::: + bodyInsts ::: + List(Instruction(Op.OpBranch, 
List(ResultRef(continueId)))) + + val contBlk = List( + Instruction(Op.OpLabel, List(ResultRef(continueId))), + Instruction(Op.OpIAdd, List(ResultRef(intTy), ResultRef(addId), ResultRef(iterPhiId), ResultRef(oneId))), + Instruction(Op.OpBranch, List(ResultRef(headerId))), + ) + + val mergeBlk = List(Instruction(Op.OpLabel, List(ResultRef(mergeId)))) + + val finalNextId = math.max(ctxAfterBody.nextResultId, addId + 1) + // The result of foldRepeat is the final accumulator value (accPhiId after merge) + // We need to map both: + // 1. The accumulator phantom treeid (for expressions that reference the accumulator) + // 2. The body result treeid (for the FlatMap chain to work correctly) + val finalCtx = ctxAfterBody.copy( + nextResultId = finalNextId, + exprRefs = ctxAfterBody.exprRefs + + (accTreeId -> accPhiId) + + (body.underlying.tree.treeid -> accPhiId), + ) + + (acc ::: nInsts ::: initInsts ::: invariantInsts ::: preheader ::: header ::: bodyBlk ::: contBlk ::: mergeBlk, finalCtx) + + /** Compiles ConditionalWhen - a proper if-then construct for conditional execution. + * Generates OpSelectionMerge + OpBranchConditional instead of a loop. + */ + private def compileConditionalWhen( + cond: io.computenode.cyfra.dsl.GBoolean, + body: GIO[?], + ctx: Context, + acc: List[Words], + ): (List[Words], Context) = + // Compile the condition + val (condInsts, ctxWithCond) = ExpressionCompiler.compileBlock(cond.tree, ctx) + val condId = ctxWithCond.exprRefs(cond.tree.treeid) + + val baseId = ctxWithCond.nextResultId + val headerLabelId = baseId + val thenLabelId = baseId + 1 + val mergeLabelId = baseId + 2 + + // Setup context for body compilation + val bodyCtx = ctxWithCond.copy(nextResultId = baseId + 3) + val (bodyInsts, ctxAfterBody) = compileGio(body, bodyCtx) + + // Header block: branch to header label, then do selection + val headerBlock = List( + Instruction(Op.OpBranch, List(ResultRef(headerLabelId))), + Instruction(Op.OpLabel, List(ResultRef(headerLabelId))), + Instruction(Op.OpSelectionMerge, List(ResultRef(mergeLabelId), SelectionControlMask.MaskNone)), + Instruction(Op.OpBranchConditional, List(ResultRef(condId), ResultRef(thenLabelId), ResultRef(mergeLabelId))), + ) + + // Then block: execute body, then branch to merge + val thenBlock = List(Instruction(Op.OpLabel, List(ResultRef(thenLabelId)))) ::: + bodyInsts ::: + List(Instruction(Op.OpBranch, List(ResultRef(mergeLabelId)))) + + // Merge block: continuation point + val mergeBlock = List(Instruction(Op.OpLabel, List(ResultRef(mergeLabelId)))) + + val finalNextId = math.max(ctxAfterBody.nextResultId, mergeLabelId + 1) + // Use ctxWithCond's exprRefs but updated nextResultId - same pattern as compileRepeat + val finalCtx = ctxWithCond.copy(nextResultId = finalNextId) + + (acc ::: condInsts ::: headerBlock ::: thenBlock ::: mergeBlock, finalCtx) + + /** Finds the CurrentFoldRepeatAcc phantom expression in a GIO tree. */ + private def findFoldRepeatAcc(gio: GIO[?]): Option[CurrentFoldRepeatAcc[?]] = + def findInExpr(expr: E[?]): Option[CurrentFoldRepeatAcc[?]] = + expr match + case acc: CurrentFoldRepeatAcc[?] 
=> Some(acc) + case _ => expr.exprDependencies.flatMap(findInExpr).headOption + + def findInGio(g: GIO[?]): Option[CurrentFoldRepeatAcc[?]] = g match + case Pure(v) => findInExpr(v.tree) + case FlatMap(v, n) => findInGio(v).orElse(findInGio(n)) + case Repeat(n, body, _) => findInExpr(n.tree).orElse(findInGio(body)) + case FoldRepeat(n, init, b, _, _) => findInExpr(n.tree).orElse(findInExpr(init.tree)).orElse(findInGio(b)) + case ConditionalWhen(cond, body) => findInExpr(cond.tree).orElse(findInGio(body)) + case WriteBuffer(_, i, v) => findInExpr(i.tree).orElse(findInExpr(v.tree)) + case WriteShared(_, i, v) => findInExpr(i.tree).orElse(findInExpr(v.tree)) + case Printf(_, args*) => args.flatMap(a => findInExpr(a.tree)).headOption + case WorkgroupBarrier => None + + findInGio(gio) + + private def collectExpressionsMap(gio: GIO[?]): Map[Int, E[?]] = + val result = mutable.Map[Int, E[?]]() + + def collectFromExpr(expr: E[?]): Unit = + if !result.contains(expr.treeid) then + result += (expr.treeid -> expr) + expr.exprDependencies.foreach(collectFromExpr) + + def collectFromGio(g: GIO[?]): Unit = g match + case Pure(v) => collectFromExpr(v.tree) + case FlatMap(v, n) => collectFromGio(v); collectFromGio(n) + case Repeat(n, body, _) => collectFromExpr(n.tree); collectFromGio(body) + case FoldRepeat(n, init, body, _, _) => collectFromExpr(n.tree); collectFromExpr(init.tree); collectFromGio(body) + case ConditionalWhen(cond, body) => collectFromExpr(cond.tree); collectFromGio(body) + case WriteBuffer(_, i, v) => collectFromExpr(i.tree); collectFromExpr(v.tree) + case WriteShared(_, i, v) => collectFromExpr(i.tree); collectFromExpr(v.tree) + case Printf(_, args*) => args.foreach(a => collectFromExpr(a.tree)) + case WorkgroupBarrier => () // No expressions to collect + + collectFromGio(gio) + result.toMap + + private def findLoopDependentExprs(exprsMap: Map[Int, E[?]], loopVarId: Int): Set[Int] = + val dependent = mutable.Set[Int](loopVarId) + var changed = true + while changed do + changed = false + exprsMap.values.foreach: expr => + if !dependent.contains(expr.treeid) then + // Check if any dependency's treeid is in dependent set + if expr.exprDependencies.exists(dep => dependent.contains(dep.treeid)) then + dependent += expr.treeid + changed = true + dependent.toSet + + /** Find expressions that depend (transitively) on scope-introducing expressions like When. 
*/ + private def findScopeDependentExprs(exprsMap: Map[Int, E[?]]): Set[Int] = + val scopeExprs = exprsMap.values.filter(_.introducedScopes.nonEmpty).map(_.treeid).toSet + val dependent = mutable.Set.from(scopeExprs) + var changed = true + while changed do + changed = false + exprsMap.values.foreach: expr => + if !dependent.contains(expr.treeid) then + if expr.exprDependencies.exists(dep => dependent.contains(dep.treeid)) then + dependent += expr.treeid + changed = true + dependent.toSet diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala index e635c4c5..57299b1f 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/GSeqCompiler.scala @@ -68,7 +68,7 @@ private[cyfra] object GSeqCompiler: ::: List( // acc = nextAcc Instruction(Op.OpStore, List(ResultRef(resultVar), ResultRef(reduceCtx.exprRefs(foldFnExpr.treeid)))), ) - (instructions, ctx.joinNested(reduceCtx)) + (instructions, context.joinNested(reduceCtx)) case (op, dExpr) :: tail => op match @@ -176,7 +176,8 @@ private[cyfra] object GSeqCompiler: ), Instruction(Op.OpBranch, List(ResultRef(loopBack))), Instruction(Op.OpLabel, List(ResultRef(loopBack))), - Instruction(Op.OpLoopMerge, List(ResultRef(mergeBlock), ResultRef(continueTarget), LoopControlMask.MaskNone)), + Instruction(Op.OpLoopMerge, List(ResultRef(mergeBlock), ResultRef(continueTarget), + if fold.unroll then LoopControlMask.Unroll else LoopControlMask.MaskNone)), Instruction(Op.OpBranch, List(ResultRef(postLoopMergeLabel))), Instruction(Op.OpLabel, List(ResultRef(postLoopMergeLabel))), Instruction(Op.OpLoad, List(ResultRef(boolType), ResultRef(shouldTakeInCheck), ResultRef(shouldTakeVar))), diff --git a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala index bd4e469c..0be33578 100644 --- a/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala +++ b/cyfra-compiler/src/main/scala/io/computenode/cyfra/spirv/compilers/SpirvProgramCompiler.scala @@ -22,31 +22,72 @@ private[cyfra] object SpirvProgramCompiler: case _ => false def compileMain(bodyIo: GIO[?], ctx: Context): (List[Words], Context) = + val int32TypeRef = ctx.valueTypeMap(Int32Tag.tag) + val vec3Int32TypeRef = ctx.valueTypeMap(summon[Tag[Vec3[Int32]]].tag) + val int32PtrInputRef = ctx.inputPointerMap(int32TypeRef) + val vec3Int32PtrInputRef = ctx.inputPointerMap(vec3Int32TypeRef) + val zeroConstRef = ctx.constRefs(Int32Tag, 0) val init = List( Instruction(Op.OpFunction, List(ResultRef(ctx.voidTypeRef), ResultRef(MAIN_FUNC_REF), SamplerAddressingMode.None, ResultRef(VOID_FUNC_TYPE_REF))), Instruction(Op.OpLabel, List(ResultRef(ctx.nextResultId))), ) - val initWorkerIndex = List( - Instruction( - Op.OpAccessChain, - List( - ResultRef(ctx.inputPointerMap(ctx.valueTypeMap(Int32Tag.tag))), - ResultRef(ctx.nextResultId + 1), - ResultRef(GL_GLOBAL_INVOCATION_ID_REF), - ResultRef(ctx.constRefs(Int32Tag, 0)), - ), - ), - Instruction(Op.OpLoad, List(ResultRef(ctx.valueTypeMap(Int32Tag.tag)), ResultRef(ctx.nextResultId + 2), ResultRef(ctx.nextResultId + 1))), + var nextId = ctx.nextResultId + 1 + + def loadScalarFromVec3(varRef: Int): (List[Words], Int) = + val ptrId = nextId + val loadId = nextId + 1 + 
nextId += 2 + val insns = List( + Instruction(Op.OpAccessChain, List(ResultRef(int32PtrInputRef), ResultRef(ptrId), ResultRef(varRef), ResultRef(zeroConstRef))), + Instruction(Op.OpLoad, List(ResultRef(int32TypeRef), ResultRef(loadId), ResultRef(ptrId))), + ) + (insns, loadId) + + def loadVec3(varRef: Int): (List[Words], Int) = + val loadId = nextId + nextId += 1 + val insns = List(Instruction(Op.OpLoad, List(ResultRef(vec3Int32TypeRef), ResultRef(loadId), ResultRef(varRef)))) + (insns, loadId) + + def loadScalar(varRef: Int): (List[Words], Int) = + val loadId = nextId + nextId += 1 + val insns = List(Instruction(Op.OpLoad, List(ResultRef(int32TypeRef), ResultRef(loadId), ResultRef(varRef)))) + (insns, loadId) + + val (globalInvocInsns, globalInvocId) = loadScalarFromVec3(GL_GLOBAL_INVOCATION_ID_REF) + val (localIdInsns, localIdRef) = loadVec3(GL_LOCAL_INVOCATION_ID_REF) + val (localIndexInsns, localIndexRef) = loadScalar(GL_LOCAL_INVOCATION_INDEX_REF) + val (workgroupIdInsns, workgroupIdLoadRef) = loadVec3(GL_WORKGROUP_ID_REF) + val (numWorkgroupsInsns, numWorkgroupsLoadRef) = loadVec3(GL_NUM_WORKGROUPS_REF) + val (subgroupIdInsns, subgroupIdLoadRef) = loadScalar(GL_SUBGROUP_ID_REF) + val (subgroupLocalIdInsns, subgroupLocalIdLoadRef) = loadScalar(GL_SUBGROUP_LOCAL_INVOCATION_ID_REF) + val (subgroupSizeInsns, subgroupSizeLoadRef) = loadScalar(GL_SUBGROUP_SIZE_REF) + + val loadInsns = globalInvocInsns ::: localIdInsns ::: localIndexInsns ::: + workgroupIdInsns ::: numWorkgroupsInsns ::: + subgroupIdInsns ::: subgroupLocalIdInsns ::: subgroupSizeInsns + + val bodyCtx = ctx.copy( + nextResultId = nextId, + workerIndexRef = globalInvocId, + localInvocationIdRef = localIdRef, + localInvocationIndexRef = localIndexRef, + workgroupIdRef = workgroupIdLoadRef, + numWorkgroupsRef = numWorkgroupsLoadRef, + subgroupIdRef = subgroupIdLoadRef, + subgroupLocalInvocationIdRef = subgroupLocalIdLoadRef, + subgroupSizeRef = subgroupSizeLoadRef, ) - val (body, codeCtx) = GIOCompiler.compileGio(bodyIo, ctx.copy(nextResultId = ctx.nextResultId + 3, workerIndexRef = ctx.nextResultId + 2)) + val (body, codeCtx) = GIOCompiler.compileGio(bodyIo, bodyCtx) val (vars, nonVarsBody) = bubbleUpVars(body) val end = List(Instruction(Op.OpReturn, List()), Instruction(Op.OpFunctionEnd, List())) - (init ::: vars ::: initWorkerIndex ::: nonVarsBody ::: end, codeCtx.copy(nextResultId = codeCtx.nextResultId + 1)) + (init ::: vars ::: loadInsns ::: nonVarsBody ::: end, codeCtx.copy(nextResultId = codeCtx.nextResultId + 1)) def getNameDecorations(ctx: Context): List[Instruction] = val funNames = ctx.functions.map { case (id, fn) => @@ -65,25 +106,59 @@ private[cyfra] object SpirvProgramCompiler: binding: Int, ) - val headers: List[Words] = + case class SharedBlock( + arrayTypeRef: Int, + varRef: Int, + pointerTypeRef: Int, + ) + + def headers(workgroupSize: (Int, Int, Int)): List[Words] = + val (localSizeX, localSizeY, localSizeZ) = workgroupSize Word(Array(0x03, 0x02, 0x23, 0x07)) :: // SPIR-V - Word(Array(0x00, 0x00, 0x01, 0x00)) :: // Version: 0.1.0 + Word(Array(0x00, 0x03, 0x01, 0x00)) :: // Version: 1.3.0 (for GroupNonUniform) Word(Array(cyfraVendorId, 0x00, 0x01, 0x00)) :: // Generator: cyfra; 1 WordVariable(BOUND_VARIABLE) :: // Bound: To be calculated Word(Array(0x00, 0x00, 0x00, 0x00)) :: // Schema: 0 - Instruction(Op.OpCapability, List(Capability.Shader)) :: // OpCapability Shader - Instruction(Op.OpExtension, List(Text("SPV_KHR_non_semantic_info"))) :: // OpExtension "SPV_KHR_non_semantic_info" - 
Instruction(Op.OpExtInstImport, List(ResultRef(GLSL_EXT_REF), Text(GLSL_EXT_NAME))) :: // OpExtInstImport "GLSL.std.450" - Instruction(Op.OpExtInstImport, List(ResultRef(DEBUG_PRINTF_REF), Text(NON_SEMANTIC_DEBUG_PRINTF))) :: // OpExtInstImport "NonSemantic.DebugPrintf" - Instruction(Op.OpMemoryModel, List(AddressingModel.Logical, MemoryModel.GLSL450)) :: // OpMemoryModel Logical GLSL450 - Instruction(Op.OpEntryPoint, List(ExecutionModel.GLCompute, ResultRef(MAIN_FUNC_REF), Text("main"), ResultRef(GL_GLOBAL_INVOCATION_ID_REF))) :: // OpEntryPoint GLCompute %MAIN_FUNC_REF "main" %GL_GLOBAL_INVOCATION_ID_REF - Instruction(Op.OpExecutionMode, List(ResultRef(MAIN_FUNC_REF), ExecutionMode.LocalSize, IntWord(256), IntWord(1), IntWord(1))) :: // OpExecutionMode %4 LocalSize 128 1 1 - Instruction(Op.OpSource, List(SourceLanguage.GLSL, IntWord(450))) :: // OpSource GLSL 450 + Instruction(Op.OpCapability, List(Capability.Shader)) :: + Instruction(Op.OpCapability, List(Capability.Float16)) :: + Instruction(Op.OpCapability, List(Capability.GroupNonUniform)) :: + Instruction(Op.OpCapability, List(Capability.GroupNonUniformArithmetic)) :: + Instruction(Op.OpExtension, List(Text("SPV_KHR_non_semantic_info"))) :: + Instruction(Op.OpExtInstImport, List(ResultRef(GLSL_EXT_REF), Text(GLSL_EXT_NAME))) :: + Instruction(Op.OpExtInstImport, List(ResultRef(DEBUG_PRINTF_REF), Text(NON_SEMANTIC_DEBUG_PRINTF))) :: + Instruction(Op.OpMemoryModel, List(AddressingModel.Logical, MemoryModel.GLSL450)) :: + Instruction( + Op.OpEntryPoint, + List( + ExecutionModel.GLCompute, + ResultRef(MAIN_FUNC_REF), + Text("main"), + ResultRef(GL_GLOBAL_INVOCATION_ID_REF), + ResultRef(GL_LOCAL_INVOCATION_ID_REF), + ResultRef(GL_LOCAL_INVOCATION_INDEX_REF), + ResultRef(GL_WORKGROUP_ID_REF), + ResultRef(GL_NUM_WORKGROUPS_REF), + ResultRef(GL_SUBGROUP_ID_REF), + ResultRef(GL_SUBGROUP_LOCAL_INVOCATION_ID_REF), + ResultRef(GL_SUBGROUP_SIZE_REF), + ), + ) :: + Instruction(Op.OpExecutionMode, List(ResultRef(MAIN_FUNC_REF), ExecutionMode.LocalSize, IntWord(localSizeX), IntWord(localSizeY), IntWord(localSizeZ))) :: + Instruction(Op.OpSource, List(SourceLanguage.GLSL, IntWord(450))) :: Nil val workgroupDecorations: List[Words] = - Instruction(Op.OpDecorate, List(ResultRef(GL_GLOBAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.GlobalInvocationId)) :: // OpDecorate %GL_GLOBAL_INVOCATION_ID_REF BuiltIn GlobalInvocationId - Instruction(Op.OpDecorate, List(ResultRef(GL_WORKGROUP_SIZE_REF), Decoration.BuiltIn, BuiltIn.WorkgroupSize)) :: Nil + List( + Instruction(Op.OpDecorate, List(ResultRef(GL_GLOBAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.GlobalInvocationId)), + Instruction(Op.OpDecorate, List(ResultRef(GL_WORKGROUP_SIZE_REF), Decoration.BuiltIn, BuiltIn.WorkgroupSize)), + Instruction(Op.OpDecorate, List(ResultRef(GL_LOCAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.LocalInvocationId)), + Instruction(Op.OpDecorate, List(ResultRef(GL_LOCAL_INVOCATION_INDEX_REF), Decoration.BuiltIn, BuiltIn.LocalInvocationIndex)), + Instruction(Op.OpDecorate, List(ResultRef(GL_WORKGROUP_ID_REF), Decoration.BuiltIn, BuiltIn.WorkgroupId)), + Instruction(Op.OpDecorate, List(ResultRef(GL_NUM_WORKGROUPS_REF), Decoration.BuiltIn, BuiltIn.NumWorkgroups)), + Instruction(Op.OpDecorate, List(ResultRef(GL_SUBGROUP_ID_REF), Decoration.BuiltIn, BuiltIn.SubgroupId)), + Instruction(Op.OpDecorate, List(ResultRef(GL_SUBGROUP_LOCAL_INVOCATION_ID_REF), Decoration.BuiltIn, BuiltIn.SubgroupLocalInvocationId)), + Instruction(Op.OpDecorate, 
List(ResultRef(GL_SUBGROUP_SIZE_REF), Decoration.BuiltIn, BuiltIn.SubgroupSize)), + ) def defineVoids(context: Context): (List[Words], Context) = val voidDef = List[Words]( @@ -93,7 +168,8 @@ private[cyfra] object SpirvProgramCompiler: val ctxWithVoid = context.copy(voidTypeRef = TYPE_VOID_REF, voidFuncTypeRef = VOID_FUNC_TYPE_REF) (voidDef, ctxWithVoid) - def createInvocationId(context: Context): (List[Words], Context) = + def createInvocationId(context: Context, workgroupSize: (Int, Int, Int)): (List[Words], Context) = + val (localSizeX, localSizeY, localSizeZ) = workgroupSize val definitionInstructions = List( Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 0), IntWord(localSizeX))), Instruction(Op.OpConstant, List(ResultRef(context.valueTypeMap(UInt32Tag.tag)), ResultRef(context.nextResultId + 1), IntWord(localSizeY))), @@ -239,7 +315,14 @@ private[cyfra] object SpirvProgramCompiler: } } - val predefinedConsts = List((Int32Tag, 0), (UInt32Tag, 0), (Int32Tag, 1)) + val predefinedConsts = List( + (Int32Tag, 0), + (UInt32Tag, 0), + (Int32Tag, 1), + (Int32Tag, Scope.Workgroup.opcode), + (Int32Tag, Scope.Subgroup.opcode), + (Int32Tag, MemorySemantics.WorkgroupMemory.opcode | MemorySemantics.AcquireRelease.opcode), + ) def defineConstants(exprs: List[E[?]], ctx: Context): (List[Words], Context) = // Collect field indices from GetField expressions val fieldIndices = exprs.collect { case gf: GetField[?, ?] => @@ -269,16 +352,18 @@ private[cyfra] object SpirvProgramCompiler: ) def defineVarNames(ctx: Context): (List[Words], Context) = + val vec3Int32PtrId = ctx.inputPointerMap(ctx.valueTypeMap(summon[Tag[Vec3[Int32]]].tag)) + val int32PtrInputId = ctx.inputPointerMap(ctx.valueTypeMap(summon[Tag[Int32]].tag)) ( List( - Instruction( - Op.OpVariable, - List( - ResultRef(ctx.inputPointerMap(ctx.valueTypeMap(summon[Tag[Vec3[Int32]]].tag))), - ResultRef(GL_GLOBAL_INVOCATION_ID_REF), - StorageClass.Input, - ), - ), + Instruction(Op.OpVariable, List(ResultRef(vec3Int32PtrId), ResultRef(GL_GLOBAL_INVOCATION_ID_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(vec3Int32PtrId), ResultRef(GL_LOCAL_INVOCATION_ID_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(int32PtrInputId), ResultRef(GL_LOCAL_INVOCATION_INDEX_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(vec3Int32PtrId), ResultRef(GL_WORKGROUP_ID_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(vec3Int32PtrId), ResultRef(GL_NUM_WORKGROUPS_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(int32PtrInputId), ResultRef(GL_SUBGROUP_ID_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(int32PtrInputId), ResultRef(GL_SUBGROUP_LOCAL_INVOCATION_ID_REF), StorageClass.Input)), + Instruction(Op.OpVariable, List(ResultRef(int32PtrInputId), ResultRef(GL_SUBGROUP_SIZE_REF), StorageClass.Input)), ), - ctx.copy(), + ctx, ) diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala index 8c539e2c..68819575 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GExecution.scala @@ -31,6 +31,18 @@ trait GExecution[-Params, ExecLayout: Layout, ResLayout: Layout]: val adapted = program.contramapParams(mapParams).contramap(mapLayout) flatMap(r => adapted.map(_ => r)) + /** Add a GPU buffer copy 
operation (uses vkCmdCopyBuffer - much faster than compute shader).
+    *
+    * @param getBuffers Function to extract (source, destination) buffers from layout
+    * @param sizeBytes  Number of bytes to copy
+    */
+  def addBufferCopy[PP <: Params](
+    getBuffers: ExecLayout => (GBuffer[?], GBuffer[?]),
+    sizeBytes: Int,
+  ): GExecution[PP, ExecLayout, ResLayout] =
+    val copyExec = BufferCopy[ExecLayout](getBuffers, sizeBytes)
+    flatMap(r => copyExec.map(_ => r))
+
 object GExecution:

   def apply[Params, L: Layout]() =
@@ -44,6 +56,12 @@ object GExecution:
   case class FlatMap[Params, EL: Layout, RL: Layout, NRL: Layout](execution: GExecution[Params, EL, RL], f: (Params, RL) => GExecution[Params, EL, NRL])
       extends GExecution[Params, EL, NRL]

+  /** GPU buffer copy using vkCmdCopyBuffer (DMA transfer, much faster than compute shader). */
+  case class BufferCopy[L: Layout](
+    getBuffers: L => (GBuffer[?], GBuffer[?]),
+    sizeBytes: Int,
+  ) extends GExecution[Any, L, L]
+
   case class Map[P, NP, EL: Layout, NEL: Layout, RL: Layout, NRL: Layout](
     execution: GExecution[P, EL, RL],
     mapResult: RL => NRL,
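A usage sketch of the new copy step (all names here are hypothetical; assumes a layout with two `Float32` buffers, and a `GProgram` that reads one and writes the other):

```scala
import io.computenode.cyfra.core.{GExecution, GProgram}
import io.computenode.cyfra.core.layout.Layout
import io.computenode.cyfra.dsl.{*, given}
import io.computenode.cyfra.dsl.binding.GBuffer

// Hypothetical ping-pong layout: `step` reads src and writes dst.
case class PingPong(
  src: GBuffer[Float32],
  dst: GBuffer[Float32],
) derives Layout

// After `step` runs, DMA dst back over src so the next dispatch sees the
// updated values - no dedicated copy shader required.
def stepAndFeedBack(step: GProgram[Int, PingPong], numElements: Int): GExecution[Int, PingPong, PingPong] =
  step.addBufferCopy(l => (l.dst, l.src), sizeBytes = numElements * 4) // 4 bytes per Float32
```

diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala
index 0a62f5b2..72a8fa77 100644
--- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala
+++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala
@@ -15,11 +15,13 @@ import izumi.reflect.Tag
 import java.io.FileInputStream
 import java.nio.file.Path
 import scala.util.Using
+import sourcecode.Enclosing

 trait GProgram[Params, L: Layout] extends GExecution[Params, L, L]:
   val layout: InitProgramLayout => Params => L
   val dispatch: (L, Params) => ProgramDispatch
   val workgroupSize: WorkDimensions
+  val name: String
   def summonLayout: Layout[L] = Layout[L]

 object GProgram:
@@ -33,11 +35,14 @@ object GProgram:
     layout: InitProgramLayout ?=> Params => L,
     dispatch: (L, Params) => ProgramDispatch,
     workgroupSize: WorkDimensions = (128, 1, 1),
-  )(body: L => GIO[?]): GProgram[Params, L] =
-    new GioProgram[Params, L](body, s => layout(using s), dispatch, workgroupSize)
+  )(body: L => GIO[?])(using enclosing: Enclosing): GProgram[Params, L] =
+    // Extract program name from enclosing context (e.g.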
"pkg.F16RMSNormProgram.forward" -> "F16RMSNormProgram") + val programName = enclosing.value.split('.').dropRight(1).lastOption.getOrElse("Program") + new GioProgram[Params, L](body, s => layout(using s), dispatch, workgroupSize, programName) - def static[Params, L: Layout](layout: InitProgramLayout ?=> Params => L, dispatchSize: Params => Int)(body: L => GIO[?]): GProgram[Params, L] = - GioProgram.apply(body, s => layout(using s), (l, p) => StaticDispatch((dispatchSize(p) + 127) / 128, 1, 1), (128, 1, 1)) + def static[Params, L: Layout](layout: InitProgramLayout ?=> Params => L, dispatchSize: Params => Int)(body: L => GIO[?])(using Enclosing): GProgram[Params, L] = + val programName = summon[Enclosing].value.split('.').dropRight(1).lastOption.getOrElse("Program") + GioProgram.apply(body, s => layout(using s), (l, p) => StaticDispatch((dispatchSize(p) + 127) / 128, 1, 1), (128, 1, 1), programName) def fromSpirvFile[Params, L: Layout]( layout: InitProgramLayout ?=> Params => L, diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala index 2e074980..3ad93e49 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/GioProgram.scala @@ -11,4 +11,5 @@ case class GioProgram[Params, L: Layout]( layout: InitProgramLayout => Params => L, dispatch: (L, Params) => ProgramDispatch, workgroupSize: WorkDimensions, + name: String = "GioProgram", ) extends GProgram[Params, L] diff --git a/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala b/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala index 3d50b2e4..0af7807d 100644 --- a/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala +++ b/cyfra-core/src/main/scala/io/computenode/cyfra/core/SpirvProgram.scala @@ -2,7 +2,7 @@ package io.computenode.cyfra.core import io.computenode.cyfra.core.layout.Layout import io.computenode.cyfra.core.GProgram.{InitProgramLayout, ProgramDispatch, WorkDimensions} -import io.computenode.cyfra.core.SpirvProgram.Operation.ReadWrite +import io.computenode.cyfra.core.SpirvProgram.Operation.{Read, ReadWrite, Write} import io.computenode.cyfra.core.SpirvProgram.{Binding, ShaderLayout} import io.computenode.cyfra.dsl.Value import io.computenode.cyfra.dsl.Value.{FromExpr, GBoolean} @@ -28,6 +28,7 @@ case class SpirvProgram[Params, L: Layout] private ( code: ByteBuffer, entryPoint: String, shaderBindings: L => ShaderLayout, + name: String = "SpirvProgram", ) extends GProgram[Params, L]: /** A hash of the shader code, entry point, workgroup size, and layout bindings. Layout and dispatch are not taken into account. @@ -63,9 +64,29 @@ object SpirvProgram: dispatch: (L, Params) => ProgramDispatch, code: ByteBuffer, ): SpirvProgram[Params, L] = - val workgroupSize = (128, 1, 1) // TODO Extract form shader + apply(layout, dispatch, code, Set.empty, "SpirvProgram") + + /** Create a SpirvProgram with explicit write tracking for smarter barrier insertion. + * + * @param writtenBindingIndices Set of binding indices (layoutOffset) that are written to by this shader. + * Bindings not in this set are treated as read-only. 
+    * @param name Name for profiling/debugging
+    */
+  def apply[Params, L: Layout](
+    layout: InitProgramLayout ?=> Params => L,
+    dispatch: (L, Params) => ProgramDispatch,
+    code: ByteBuffer,
+    writtenBindingIndices: Set[Int],
+    name: String,
+  ): SpirvProgram[Params, L] =
+    val workgroupSize = (128, 1, 1) // TODO Extract from shader
+    val main = "main"
     val f: L => ShaderLayout = { case layout: Product =>
-      layout.productIterator.zipWithIndex.map { case (binding: GBinding[?], i) => Binding(binding, ReadWrite) }.toSeq.pipe(Seq(_))
+      layout.productIterator.zipWithIndex.map { case (binding: GBinding[?], i) =>
+        val op = if writtenBindingIndices.isEmpty then ReadWrite // Fallback for legacy code
+          else if writtenBindingIndices.contains(i) then Write
+          else Read
+        Binding(binding, op)
+      }.toSeq.pipe(Seq(_))
     }
-    new SpirvProgram[Params, L]((il: InitProgramLayout) => layout(using il), dispatch, workgroupSize, code, main, f)
+    new SpirvProgram[Params, L]((il: InitProgramLayout) => layout(using il), dispatch, workgroupSize, code, main, f, name)
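A sketch of the intended call shape (the layout, params, shader bytes, and the `StaticDispatch` import are assumptions, following the GProgram examples; only the output binding at index 2 is declared as written, so the runtime can avoid barriers between back-to-back read-only users of the inputs):

```scala
import java.nio.ByteBuffer
import io.computenode.cyfra.core.SpirvProgram
import io.computenode.cyfra.core.layout.Layout
import io.computenode.cyfra.dsl.{*, given}
import io.computenode.cyfra.dsl.binding.GBuffer

case class MatmulParams(rows: Int, cols: Int)

case class MatmulLayout( // bindings 0 and 1 are read-only, binding 2 is written
  input: GBuffer[Float16],
  weights: GBuffer[Float16],
  output: GBuffer[Float16],
) derives Layout

def matmul(spirv: ByteBuffer): SpirvProgram[MatmulParams, MatmulLayout] =
  SpirvProgram[MatmulParams, MatmulLayout](
    layout = p =>
      MatmulLayout(
        input = GBuffer[Float16](p.rows * p.cols),
        weights = GBuffer[Float16](p.rows * p.cols),
        output = GBuffer[Float16](p.rows),
      ),
    dispatch = (_, p) => StaticDispatch(((p.rows + 127) / 128, 1, 1)),
    code = spirv,
    writtenBindingIndices = Set(2), // only `output` is written by the shader
    name = "Matmul",
  )
```

diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala
index 7d52eb5e..283fd465 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Expression.scala
@@ -87,15 +87,20 @@ object Expression:
   sealed trait ConvertExpression[F <: Scalar: Tag, T <: Scalar: Tag] extends Expression[T]:
     def fromTag: Tag[F] = summon[Tag[F]]
     def a: F
+  case class ToFloat16[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Float16]
   case class ToFloat32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Float32]
   case class ToInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, Int32]
   case class ToUInt32[T <: Scalar: Tag](a: T) extends ConvertExpression[T, UInt32]

+  /** Convert Vec4[Float16] to Vec4[Float32] using OpFConvert.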
*/ + case class ConvertVec4F16ToF32(a: Vec4[Float16]) extends Expression[Vec4[Float32]] + sealed trait Const[T <: Scalar: Tag] extends Expression[T]: def value: Any object Const: def unapply[T <: Scalar](c: Const[T]): Option[Any] = Some(c.value) + case class ConstFloat16(value: Float) extends Const[Float16] case class ConstFloat32(value: Float) extends Const[Float32] case class ConstInt32(value: Int) extends Const[Int32] case class ConstUInt32(value: Int) extends Const[UInt32] @@ -115,3 +120,33 @@ object Expression: case object WorkerIndex extends E[Int32] case class Binding[T <: Value: Tag](binding: Int) extends E[T] + + // Workgroup built-ins + case object LocalInvocationIndex extends E[Int32] + case object LocalInvocationId extends E[Vec3[Int32]] + case object WorkgroupId extends E[Vec3[Int32]] + case object NumWorkgroups extends E[Vec3[Int32]] + case object SubgroupId extends E[Int32] + case object SubgroupLocalInvocationId extends E[Int32] + case object SubgroupSize extends E[Int32] + + // Subgroup operations + sealed trait SubgroupOp + object SubgroupOp: + case object Reduce extends SubgroupOp + case object InclusiveScan extends SubgroupOp + case object ExclusiveScan extends SubgroupOp + + case class SubgroupAddI(value: Int32, op: SubgroupOp) extends E[Int32] + case class SubgroupAddF16(value: Float16, op: SubgroupOp) extends E[Float16] + case class SubgroupAddF(value: Float32, op: SubgroupOp) extends E[Float32] + case class SubgroupMinI(value: Int32, op: SubgroupOp) extends E[Int32] + case class SubgroupMinF16(value: Float16, op: SubgroupOp) extends E[Float16] + case class SubgroupMinF(value: Float32, op: SubgroupOp) extends E[Float32] + case class SubgroupMaxI(value: Int32, op: SubgroupOp) extends E[Int32] + case class SubgroupMaxF16(value: Float16, op: SubgroupOp) extends E[Float16] + case class SubgroupMaxF(value: Float32, op: SubgroupOp) extends E[Float32] + case class SubgroupBroadcast[T <: Value.Scalar: Tag](value: T, lane: Int32) extends E[T] + case class SubgroupBroadcastFirst[T <: Value.Scalar: Tag](value: T) extends E[T] + case class SubgroupShuffle[T <: Value.Scalar: Tag](value: T, lane: Int32) extends E[T] + case class SubgroupShuffleXor[T <: Value.Scalar: Tag](value: T, mask: Int32) extends E[T] diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala index 1e8a0e92..de4bd094 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/Value.scala @@ -24,6 +24,16 @@ object Value: sealed trait Scalar extends Value trait FloatType extends Scalar + + /** 16-bit floating point (half precision) - supported in Vulkan for memory bandwidth savings */ + case class Float16(tree: E[Float16])(using val source: Source) extends FloatType + given FromExpr[Float16] with + def fromExpr(f: E[Float16])(using Source) = Float16(f) + + /** Factory method for creating Float16 constants */ + object Float16: + def apply(value: Float)(using Source): Float16 = Float16(Expression.ConstFloat16(value)) + case class Float32(tree: E[Float32])(using val source: Source) extends FloatType given FromExpr[Float32] with def fromExpr(f: E[Float32])(using Source) = Float32(f) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala index 475b936e..31403fe3 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala 
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/ScalarAlgebra.scala @@ -1,6 +1,6 @@ package io.computenode.cyfra.dsl.algebra -import io.computenode.cyfra.dsl.Expression.ConstFloat32 +import io.computenode.cyfra.dsl.Expression.{ConstFloat16, ConstFloat32} import io.computenode.cyfra.dsl.Value.* import io.computenode.cyfra.dsl.Expression.* import io.computenode.cyfra.dsl.library.Functions.abs @@ -22,6 +22,7 @@ object ScalarAlgebra: trait BasicScalarIntAlgebra[T <: Scalar: {FromExpr, Tag}] extends BasicScalarAlgebra[T] with BitwiseOperable[T] + given BasicScalarAlgebra[Float16] = new BasicScalarAlgebra[Float16] {} given BasicScalarAlgebra[Float32] = new BasicScalarAlgebra[Float32] {} given BasicScalarIntAlgebra[Int32] = new BasicScalarIntAlgebra[Int32] {} given BasicScalarIntAlgebra[UInt32] = new BasicScalarIntAlgebra[UInt32] {} @@ -92,16 +93,27 @@ object ScalarAlgebra: given Epsilon = Epsilon(0.00001f) + extension (f16: Float16) + inline def asFloat32(using Source): Float32 = Float32(ToFloat32(f16)) + inline def asInt(using Source): Int32 = f16.asFloat32.asInt + + extension (f32: Float32) + inline def asFloat16(using Source): Float16 = Float16(ToFloat16(f32)) + extension (f32: Float32) + /** Convert Float32 to Float16 constant for DSL usage */ + inline def toF16(using Source): Float16 = Float16(ToFloat16(f32)) inline def asInt(using Source): Int32 = Int32(ToInt32(f32)) inline def =~=(other: Float32)(using epsilon: Epsilon): GBoolean = abs(f32 - other) < epsilon.eps extension (i32: Int32) + inline def asFloat16(using Source): Float16 = Float16(ToFloat16(i32)) inline def asFloat(using Source): Float32 = Float32(ToFloat32(i32)) inline def unsigned(using Source): UInt32 = UInt32(ToUInt32(i32)) - + extension (u32: UInt32) + inline def asFloat16(using Source): Float16 = Float16(ToFloat16(u32)) inline def asFloat(using Source): Float32 = Float32(ToFloat32(u32)) inline def signed(using Source): Int32 = Int32(ToInt32(u32)) diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala index 7908f63f..eee2b0e0 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/algebra/VectorAlgebra.scala @@ -119,6 +119,10 @@ object VectorAlgebra: inline def xyz(using Source): Vec3[T] = Vec3(ComposeVec3(x, y, z)) inline def rgb(using Source): Vec3[T] = xyz + /** Convert Vec4[Float16] to Vec4[Float32] for higher precision operations. 
+    */
+  extension (v4f16: Vec4[Float16])
+    inline def asVec4F32(using Source): Vec4[Float32] = Vec4(ConvertVec4F16ToF32(v4f16))
+
   given (using Source): Conversion[(Int, Int), Vec2[Int32]] = { case (x, y) =>
     Vec2(ComposeVec2(Int32(ConstInt32(x)), Int32(ConstInt32(y))))
   }
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GBinding.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GBinding.scala
index 27f25d04..5510623a 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GBinding.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GBinding.scala
@@ -2,7 +2,7 @@ package io.computenode.cyfra.dsl.binding

 import io.computenode.cyfra.dsl.Value
 import io.computenode.cyfra.dsl.Value.FromExpr.fromExpr as fromExprEval
-import io.computenode.cyfra.dsl.Value.{FromExpr, Int32}
+import io.computenode.cyfra.dsl.Value.{FloatType, FromExpr, Int32, Vec4}
 import io.computenode.cyfra.dsl.gio.GIO
 import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema}
 import io.computenode.cyfra.dsl.struct.GStruct.Empty
@@ -17,7 +17,15 @@ trait GBuffer[T <: Value: {FromExpr, Tag}] extends GBinding[T]:

   def write(index: Int32, value: T): GIO[Empty] = GIO.write(this, index, value)

-object GBuffer
+object GBuffer:
+  /** Extension to read 4 consecutive elements as Vec4 from a float buffer.
+    * @param buffer The buffer to read from
+    * @param index Base index - reads elements at index, index+1, index+2, index+3
+    * @return Vec4 containing the 4 consecutive values
+    */
+  extension [T <: FloatType: Tag: FromExpr](buffer: GBuffer[T])
+    def readVec4(index: Int32)(using Tag[Vec4[T]]): Vec4[T] =
+      FromExpr.fromExpr[Vec4[T]](ReadBufferVec4(buffer, index))

 trait GUniform[T <: GStruct[?]: {Tag, FromExpr, GStructSchema}] extends GBinding[T]:
   def read: T = fromExprEval(ReadUniform(this))
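For instance, a minimal sketch combining the two new pieces (the buffer name and `dot4` are hypothetical; assumes the usual x/y/z/w vector accessors): read four consecutive half floats in one expression and widen them for accumulation in full precision.

```scala
import io.computenode.cyfra.dsl.{*, given}
import io.computenode.cyfra.dsl.binding.GBuffer
import io.computenode.cyfra.dsl.binding.GBuffer.readVec4

// Hypothetical: sum four F16 weights in F32 to limit rounding error.
def dot4(weights: GBuffer[Float16], base: Int32): Float32 =
  val w: Vec4[Float32] = weights.readVec4(base).asVec4F32 // widen after the packed read
  w.x + w.y + w.z + w.w
```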
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GShared.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GShared.scala
new file mode 100644
index 00000000..3d4c55d5
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/GShared.scala
@@ -0,0 +1,39 @@
+package io.computenode.cyfra.dsl.binding
+
+import io.computenode.cyfra.dsl.Expression.E
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.{FromExpr, Int32}
+import io.computenode.cyfra.dsl.Value.FromExpr.fromExpr
+import io.computenode.cyfra.dsl.gio.GIO
+import io.computenode.cyfra.dsl.struct.GStruct.Empty
+import izumi.reflect.Tag
+
+/**
+ * Represents a workgroup-local shared memory array.
+ *
+ * Shared memory is visible to all invocations within a workgroup and can be used
+ * for efficient inter-thread communication within a workgroup after synchronization
+ * with [[GIO.barrier]].
+ *
+ * @tparam T Element type of the shared memory array
+ */
+trait GShared[T <: Value: {FromExpr, Tag}]:
+  def tag: Tag[T] = summon[Tag[T]]
+  def size: Int
+
+  /** Read a value from shared memory at the given index. */
+  def read(index: Int32): T = fromExpr(ReadShared(this, index))
+
+  /** Write a value to shared memory at the given index. */
+  def write(index: Int32, value: T): GIO[Empty] = WriteShared(this, index, value)
+
+object GShared:
+  private var nextId = 0
+
+  /** Create a shared memory array with the given size. */
+  def apply[T <: Value: {FromExpr, Tag}](size: Int): GShared[T] =
+    val id = nextId
+    nextId += 1
+    new GSharedImpl[T](id, size)
+
+  private[cyfra] class GSharedImpl[T <: Value: {FromExpr, Tag}](val sharedId: Int, val size: Int) extends GShared[T]
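Taken together with `GIO.barrier` and `GIO.foldRepeat` (both introduced later in this diff), shared memory enables classic workgroup reductions. A minimal sketch, assuming a (256, 1, 1) workgroup size, the DSL's implicit Int to Int32 conversions, and the usual Vec3 component accessors; `workgroupSum` and the buffer names are hypothetical:

```scala
import io.computenode.cyfra.dsl.{*, given}
import io.computenode.cyfra.dsl.binding.{GBuffer, GShared}
import io.computenode.cyfra.dsl.gio.GIO
import io.computenode.cyfra.dsl.struct.GStruct.Empty

// Hypothetical kernel body for a (256, 1, 1) workgroup: each workgroup sums
// 256 consecutive input elements into one output slot.
def workgroupSum(input: GBuffer[Float32], output: GBuffer[Float32]): GIO[Empty] =
  val tile = GShared[Float32](256) // one shared slot per thread
  val lid = GIO.localInvocationIndex
  tile
    .write(lid, input.read(GIO.invocationId))
    .flatMap(_ => GIO.barrier) // make every staged value visible
    .flatMap(_ =>
      // Tree reduction: the active stride halves each round (128, 64, ..., 1),
      // with a barrier between rounds; the stride is carried as the accumulator.
      GIO.foldRepeat[Int32](8, 128)((_, stride) =>
        GIO
          .when(lid < stride)(tile.write(lid, tile.read(lid) + tile.read(lid + stride)))
          .flatMap(_ => GIO.barrier)
          .flatMap(_ => GIO.pure(stride / 2)),
      ),
    )
    .flatMap(_ => GIO.when(lid < 1)(output.write(GIO.workgroupId.x, tile.read(0))))
```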
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala
index e0057720..159b67ed 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadBuffer.scala
@@ -1,7 +1,17 @@ package io.computenode.cyfra.dsl.binding

-import io.computenode.cyfra.dsl.Value.Int32
+import io.computenode.cyfra.dsl.Value.{Float16, Float32, Int32, Vec4}
 import io.computenode.cyfra.dsl.{Expression, Value}
 import izumi.reflect.Tag

 case class ReadBuffer[T <: Value: Tag](buffer: GBuffer[T], index: Int32) extends Expression[T]
+
+/** Reads 4 consecutive elements from a buffer and returns them as a Vec4.
+  * The index is the base index - elements at index, index+1, index+2, index+3 are read.
+  * This compiles to 4 scalar loads + OpCompositeConstruct.
+  *
+  * Note: For truly coalesced Vec4 loads, use GBufferVec4 which declares the buffer
+  * with Vec4 element type.
+  */
+case class ReadBufferVec4[T <: Value.FloatType: Tag](buffer: GBuffer[T], index: Int32)(using vecTag: Tag[Vec4[T]])
+    extends Expression[Vec4[T]](using vecTag)
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadShared.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadShared.scala
new file mode 100644
index 00000000..8da5c592
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/ReadShared.scala
@@ -0,0 +1,11 @@
+package io.computenode.cyfra.dsl.binding
+
+import io.computenode.cyfra.dsl.Expression
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.{FromExpr, Int32}
+import izumi.reflect.Tag
+
+case class ReadShared[T <: Value: {Tag, FromExpr}](
+  buffer: GShared[T],
+  index: Int32,
+) extends Expression[T]
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala
index 1856079a..eae1642d 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteBuffer.scala
@@ -6,4 +6,6 @@ import io.computenode.cyfra.dsl.struct.GStruct.Empty

 case class WriteBuffer[T <: Value](buffer: GBuffer[T], index: Int32, value: T) extends GIO[Empty]:
-  override def underlying: Empty = Empty()
+  // Cache the underlying value to ensure a stable treeid for compiler lookups
+  private lazy val _underlying: Empty = Empty()
+  override def underlying: Empty = _underlying
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteShared.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteShared.scala
new file mode 100644
index 00000000..00808863
--- /dev/null
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/binding/WriteShared.scala
@@ -0,0 +1,16 @@
+package io.computenode.cyfra.dsl.binding
+
+import io.computenode.cyfra.dsl.Value
+import io.computenode.cyfra.dsl.Value.{FromExpr, Int32}
+import io.computenode.cyfra.dsl.gio.GIO
+import io.computenode.cyfra.dsl.struct.GStruct.Empty
+import izumi.reflect.Tag
+
+case class WriteShared[T <: Value: {Tag, FromExpr}](
+  buffer: GShared[T],
+  index: Int32,
+  value: T,
+) extends GIO[Empty]:
+  // Cache the underlying value to ensure a stable treeid for compiler lookups
+  private lazy val _underlying: Empty = Empty()
+  override def underlying: Empty = _underlying
diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala
index b4265a1b..03eb23d2 100644
--- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala
+++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/collections/GSeq.scala
@@ -17,6 +17,7 @@ class GSeq[T <: Value: {Tag, FromExpr}](
   val name: Source,
   val currentElemExprTreeId: Int = treeidState.getAndIncrement(),
   val aggregateElemExprTreeId: Int = treeidState.getAndIncrement(),
+  val shouldUnroll: Boolean = false,
 ):

   def copyWithDynamicTrees[R <: Value: {Tag, FromExpr}](
@@ -24,7 +25,8 @@
     limit: Option[Int] = limit,
     currentElemExprTreeId: Int = currentElemExprTreeId,
     aggregateElemExprTreeId: Int = aggregateElemExprTreeId,
-  ) = GSeq[R](uninitSource, elemOps, limit, name, currentElemExprTreeId, aggregateElemExprTreeId)
+    shouldUnroll: Boolean = shouldUnroll,
+  ) = GSeq[R](uninitSource, elemOps, limit, name, currentElemExprTreeId, aggregateElemExprTreeId, shouldUnroll)

   private val currentElemExpr = CurrentElem[T](currentElemExprTreeId)
   val source = uninitSource(currentElemExpr)
@@ -43,8 +45,15 @@
   def limit(n: Int): GSeq[T] =
     this.copyWithDynamicTrees(limit = Some(n))

+  /** Mark this sequence for loop unrolling in the generated shader.
+    * This sets the SPIR-V LoopControl Unroll hint on the compiled loop, asking the
+    * driver compiler to fully unroll it; best for small fixed-size loops.
+ */ + def unroll: GSeq[T] = + this.copyWithDynamicTrees(shouldUnroll = true) + def fold[R <: Value: {Tag, FromExpr}](zero: R, fn: (R, T) => R): R = - summon[FromExpr[R]].fromExpr(GSeq.FoldSeq(zero, fn(aggregateElem, currentElem).tree, this)) + summon[FromExpr[R]].fromExpr(GSeq.FoldSeq(zero, fn(aggregateElem, currentElem).tree, this, shouldUnroll)) def count: Int32 = fold(0, (acc: Int32, _: T) => acc + 1) @@ -90,7 +99,7 @@ object GSeq: sealed trait GSeqSource[T <: Value: Tag] case class GSeqStream[T <: Value: Tag](init: T, next: Expression[?]) extends GSeqSource[T] - case class FoldSeq[R <: Value: Tag, T <: Value: Tag](zero: R, fn: Expression[?], seq: GSeq[T]) extends Expression[R]: + case class FoldSeq[R <: Value: Tag, T <: Value: Tag](zero: R, fn: Expression[?], seq: GSeq[T], unroll: Boolean = false) extends Expression[R]: val zeroExpr = zero.tree val fnExpr = fn val streamInitExpr = seq.source.init.tree diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala index 09373068..1f5988e1 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/gio/GIO.scala @@ -1,7 +1,8 @@ package io.computenode.cyfra.dsl.gio import io.computenode.cyfra.dsl.{*, given} -import io.computenode.cyfra.dsl.Value.{FromExpr, Int32} +import io.computenode.cyfra.dsl.Expression.{CustomTreeId, PhantomExpression, treeidState, *, given} +import io.computenode.cyfra.dsl.Value.{FromExpr, Int32, UInt32, Float16, Float32, Vec3, Vec4} import io.computenode.cyfra.dsl.Value.FromExpr.fromExpr import io.computenode.cyfra.dsl.binding.{GBuffer, ReadBuffer, WriteBuffer} import io.computenode.cyfra.dsl.collections.GSeq @@ -10,6 +11,10 @@ import io.computenode.cyfra.dsl.struct.GStruct.Empty import io.computenode.cyfra.dsl.control.When import izumi.reflect.Tag +/** + * GPU I/O monad for representing side-effectful GPU operations. + * Supports buffer reads/writes, synchronization barriers, and workgroup-level operations. + */ trait GIO[T <: Value]: def flatMap[U <: Value](f: T => GIO[U]): GIO[U] = FlatMap(this, f(this.underlying)) @@ -26,13 +31,39 @@ object GIO: case class FlatMap[T <: Value, U <: Value](gio: GIO[T], next: GIO[U]) extends GIO[U]: override def underlying: U = next.underlying - // TODO repeat that collects results - case class Repeat(n: Int32, f: GIO[?]) extends GIO[Empty]: + /** Loop that repeats n times without accumulator. + * + * @param n Number of iterations + * @param f Body GIO to execute + * @param unroll Whether to hint the GPU compiler to unroll this loop + */ + case class Repeat(n: Int32, f: GIO[?], unroll: Boolean = false) extends GIO[Empty]: override def underlying: Empty = Empty() + /** Folding repeat with accumulator - enables accumulation across iterations with barriers. + * + * @param n Number of iterations + * @param init Initial accumulator value + * @param body Body GIO that returns new accumulator + * @param accTreeId Treeid of the CurrentFoldRepeatAcc phantom for binding + * @param unroll Whether to hint the GPU compiler to unroll this loop + */ + case class FoldRepeat[A <: Value](n: Int32, init: A, body: GIO[A], accTreeId: Int, unroll: Boolean = false) extends GIO[A]: + override def underlying: A = body.underlying + case class Printf(format: String, args: Value*) extends GIO[Empty]: override def underlying: Empty = Empty() + /** Conditional execution - executes body only if condition is true. 
+ * Compiled to proper if-then structure (OpSelectionMerge + OpBranchConditional). + */ + case class ConditionalWhen(cond: GBoolean, body: GIO[?]) extends GIO[Empty]: + override def underlying: Empty = Empty() + + /** Memory and execution barrier for workgroup synchronization. */ + case object WorkgroupBarrier extends GIO[Empty]: + override def underlying: Empty = Empty() + def pure[T <: Value](value: T): GIO[T] = Pure(value) def value[T <: Value](value: T): GIO[T] = Pure(value) @@ -40,8 +71,48 @@ object GIO: case object CurrentRepeatIndex extends PhantomExpression[Int32] with CustomTreeId: override val treeid: Int = treeidState.getAndIncrement() + /** Phantom expression for the current accumulator value in foldRepeat. */ + case class CurrentFoldRepeatAcc[A <: Value: Tag](init: A, tid: Int) extends PhantomExpression[A] with CustomTreeId: + override val treeid: Int = tid + def repeat(n: Int32)(f: Int32 => GIO[?]): GIO[Empty] = - Repeat(n, f(fromExpr(CurrentRepeatIndex))) + Repeat(n, f(fromExpr(CurrentRepeatIndex)), unroll = false) + + /** Repeat with loop unroll hint. The GPU compiler will attempt to fully unroll + * this loop for better performance. Use for small, fixed-size loops. + */ + def repeatUnroll(n: Int32)(f: Int32 => GIO[?]): GIO[Empty] = + Repeat(n, f(fromExpr(CurrentRepeatIndex)), unroll = true) + + /** Folding repeat - accumulates a value across iterations, supporting barriers. + * + * Unlike `GSeq.fold`, this supports side effects (barriers, writes) within the loop body. + * The body receives the current iteration index and current accumulator value, + * and returns the new accumulator value wrapped in GIO. + * + * @param n Number of iterations + * @param init Initial accumulator value + * @param body Function taking (iterationIndex, currentAcc) and returning new acc in GIO + * @return Final accumulated value + */ + def foldRepeat[A <: Value: {FromExpr, Tag}](n: Int32, init: A)(body: (Int32, A) => GIO[A]): GIO[A] = + val tid = treeidState.getAndIncrement() + val accExpr = CurrentFoldRepeatAcc(init, tid) + FoldRepeat(n, init, body(fromExpr(CurrentRepeatIndex), fromExpr(accExpr)), tid, unroll = false) + + /** Folding repeat with loop unroll hint. The GPU compiler will attempt to fully + * unroll this loop for better performance. Use for small, fixed-size inner loops + * (e.g., head dimension in attention, vector dot products). 
+ * + * @param n Number of iterations (should be a small constant for effective unrolling) + * @param init Initial accumulator value + * @param body Function taking (iterationIndex, currentAcc) and returning new acc in GIO + * @return Final accumulated value + */ + def foldRepeatUnroll[A <: Value: {FromExpr, Tag}](n: Int32, init: A)(body: (Int32, A) => GIO[A]): GIO[A] = + val tid = treeidState.getAndIncrement() + val accExpr = CurrentFoldRepeatAcc(init, tid) + FoldRepeat(n, init, body(fromExpr(CurrentRepeatIndex), fromExpr(accExpr)), tid, unroll = true) def write[T <: Value](buffer: GBuffer[T], index: Int32, value: T): GIO[Empty] = WriteBuffer(buffer, index, value) @@ -50,12 +121,147 @@ object GIO: Printf(s"|$format", args*) def when(cond: GBoolean)(thenCode: GIO[?]): GIO[Empty] = - val n = When.when(cond)(1: Int32).otherwise(0) - repeat(n): _ => - thenCode + ConditionalWhen(cond, thenCode) def read[T <: Value: {FromExpr, Tag}](buffer: GBuffer[T], index: Int32): T = fromExpr(ReadBuffer(buffer, index)) + import scala.annotation.targetName + + // ───────────────────────────────────────────────────────────────────────────── + // Global Invocation + // ───────────────────────────────────────────────────────────────────────────── + + /** Global invocation index (gl_GlobalInvocationID.x). */ def invocationId: Int32 = fromExpr(InvocationId) + + // ───────────────────────────────────────────────────────────────────────────── + // Workgroup Primitives + // ───────────────────────────────────────────────────────────────────────────── + + /** Local invocation index within workgroup (gl_LocalInvocationIndex). */ + def localInvocationIndex: Int32 = + fromExpr(LocalInvocationIndex) + + /** Local invocation ID as 3D vector (gl_LocalInvocationID). */ + def localInvocationId: Vec3[Int32] = + fromExpr(LocalInvocationId) + + /** Workgroup ID as 3D vector (gl_WorkGroupID). */ + def workgroupId: Vec3[Int32] = + fromExpr(WorkgroupId) + + /** Number of workgroups as 3D vector (gl_NumWorkGroups). */ + def numWorkgroups: Vec3[Int32] = + fromExpr(NumWorkgroups) + + /** Synchronization barrier for workgroup memory and execution. */ + def barrier: GIO[Empty] = WorkgroupBarrier + + // ───────────────────────────────────────────────────────────────────────────── + // Subgroup Primitives + // ───────────────────────────────────────────────────────────────────────────── + + /** Subgroup ID within the workgroup. */ + def subgroupId: Int32 = + fromExpr(SubgroupId) + + /** Local invocation ID within the subgroup. */ + def subgroupLocalInvocationId: Int32 = + fromExpr(SubgroupLocalInvocationId) + + /** Size of subgroup (typically 32 for NVIDIA, 64 for AMD). */ + def subgroupSize: Int32 = + fromExpr(SubgroupSize) + + // ───────────────────────────────────────────────────────────────────────────── + // Subgroup Collective Operations + // ───────────────────────────────────────────────────────────────────────────── + + /** Reduces values across the subgroup using addition. */ + def subgroupAdd(value: Int32): Int32 = + fromExpr(SubgroupAddI(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using addition. */ + @targetName("subgroupAddF16") + def subgroupAdd(value: Float16): Float16 = + fromExpr(SubgroupAddF16(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using addition. */ + def subgroupAdd(value: Float32): Float32 = + fromExpr(SubgroupAddF(value, SubgroupOp.Reduce)) + + /** Inclusive prefix sum across the subgroup. 
*/ + def subgroupInclusiveAdd(value: Int32): Int32 = + fromExpr(SubgroupAddI(value, SubgroupOp.InclusiveScan)) + + /** Inclusive prefix sum across the subgroup. */ + @targetName("subgroupInclusiveAddF16") + def subgroupInclusiveAdd(value: Float16): Float16 = + fromExpr(SubgroupAddF16(value, SubgroupOp.InclusiveScan)) + + /** Inclusive prefix sum across the subgroup. */ + def subgroupInclusiveAdd(value: Float32): Float32 = + fromExpr(SubgroupAddF(value, SubgroupOp.InclusiveScan)) + + /** Exclusive prefix sum across the subgroup. */ + def subgroupExclusiveAdd(value: Int32): Int32 = + fromExpr(SubgroupAddI(value, SubgroupOp.ExclusiveScan)) + + /** Exclusive prefix sum across the subgroup. */ + @targetName("subgroupExclusiveAddF16") + def subgroupExclusiveAdd(value: Float16): Float16 = + fromExpr(SubgroupAddF16(value, SubgroupOp.ExclusiveScan)) + + /** Exclusive prefix sum across the subgroup. */ + def subgroupExclusiveAdd(value: Float32): Float32 = + fromExpr(SubgroupAddF(value, SubgroupOp.ExclusiveScan)) + + /** Reduces values across the subgroup using minimum. */ + def subgroupMin(value: Int32): Int32 = + fromExpr(SubgroupMinI(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using minimum. */ + @targetName("subgroupMinF16") + def subgroupMin(value: Float16): Float16 = + fromExpr(SubgroupMinF16(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using minimum. */ + def subgroupMin(value: Float32): Float32 = + fromExpr(SubgroupMinF(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using maximum. */ + def subgroupMax(value: Int32): Int32 = + fromExpr(SubgroupMaxI(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using maximum. */ + @targetName("subgroupMaxF16") + def subgroupMax(value: Float16): Float16 = + fromExpr(SubgroupMaxF16(value, SubgroupOp.Reduce)) + + /** Reduces values across the subgroup using maximum. */ + def subgroupMax(value: Float32): Float32 = + fromExpr(SubgroupMaxF(value, SubgroupOp.Reduce)) + + /** Broadcasts a value from a specific lane to all lanes in the subgroup. */ + def subgroupBroadcast[T <: Value.Scalar: {FromExpr, Tag}](value: T, lane: Int32): T = + fromExpr(SubgroupBroadcast(value, lane)) + + /** Broadcasts a value from the first active lane to all lanes in the subgroup. */ + def subgroupBroadcastFirst[T <: Value.Scalar: {FromExpr, Tag}](value: T): T = + fromExpr(SubgroupBroadcastFirst(value)) + + /** Shuffles a value from another lane in the subgroup. */ + def subgroupShuffle[T <: Value.Scalar: {FromExpr, Tag}](value: T, lane: Int32): T = + fromExpr(SubgroupShuffle(value, lane)) + + /** Shuffles a value using XOR of lane index with mask. + * This is useful for butterfly/tree reductions where each thread exchanges + * data with thread at (laneId XOR mask). For example: + * - mask=1: lanes 0↔1, 2↔3, 4↔5, ... + * - mask=2: lanes 0↔2, 1↔3, 4↔6, ... + * - mask=4: lanes 0↔4, 1↔5, 2↔6, ... 
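+    * For example, a full butterfly reduction over a 32-lane subgroup can be built by
+    * halving the mask (a sketch; `v` is any per-lane Float32):
+    * {{{
+    * val s16 = v   + GIO.subgroupShuffleXor(v, 16)
+    * val s8  = s16 + GIO.subgroupShuffleXor(s16, 8)
+    * val s4  = s8  + GIO.subgroupShuffleXor(s8, 4)
+    * val s2  = s4  + GIO.subgroupShuffleXor(s4, 2)
+    * val sum = s2  + GIO.subgroupShuffleXor(s2, 1) // every lane now holds the total
+    * }}}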
+ */ + def subgroupShuffleXor[T <: Value.Scalar: {FromExpr, Tag}](value: T, mask: Int32): T = + fromExpr(SubgroupShuffleXor(value, mask)) \ No newline at end of file diff --git a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala index 26b4a970..6abd3dcb 100644 --- a/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala +++ b/cyfra-dsl/src/main/scala/io/computenode/cyfra/dsl/library/Functions.scala @@ -13,26 +13,33 @@ object Functions: case object Sin extends FunctionName def sin(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Sin, List(v))) + def sin(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Sin, List(v))) case object Cos extends FunctionName def cos(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Cos, List(v))) + def cos(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Cos, List(v))) def cos[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Cos, List(v))) case object Tan extends FunctionName def tan(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Tan, List(v))) + def tan(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Tan, List(v))) case object Acos extends FunctionName def acos(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Acos, List(v))) + def acos(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Acos, List(v))) case object Asin extends FunctionName def asin(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Asin, List(v))) + def asin(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Asin, List(v))) case object Atan extends FunctionName def atan(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Atan, List(v))) + def atan(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Atan, List(v))) case object Atan2 extends FunctionName def atan2(y: Float32, x: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Atan2, List(y, x))) + def atan2(y: Float16, x: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Atan2, List(y, x))) case object Len2 extends FunctionName def length[T <: Scalar: Tag](v: Vec2[T])(using Source): Float32 = Float32(ExtFunctionCall(Len2, List(v))) @@ -43,14 +50,18 @@ object Functions: case object Pow extends FunctionName def pow(v: Float32, p: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Pow, List(v, p))) + def pow(v: Float16, p: Float16)(using Source): Float16 = + Float16(ExtFunctionCall(Pow, List(v, p))) def pow[V <: Vec[?]: {Tag, FromExpr}](v: V, p: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Pow, List(v, p))) case object Smoothstep extends FunctionName def smoothstep(edge0: Float32, edge1: Float32, x: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Smoothstep, List(edge0, edge1, x))) + def smoothstep(edge0: Float16, edge1: Float16, x: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Smoothstep, List(edge0, edge1, x))) case object Sqrt extends FunctionName def sqrt(v: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Sqrt, List(v))) + def sqrt(v: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Sqrt, List(v))) case object Cross extends FunctionName def cross[T <: Scalar: Tag](v1: Vec3[T], v2: Vec3[T])(using Source): Vec3[T] = Vec3(ExtFunctionCall(Cross, List(v1, v2))) @@ -61,12 +72,14 @@ object Functions: case object Exp extends FunctionName def exp(f: Float32)(using 
Source): Float32 = Float32(ExtFunctionCall(Exp, List(f))) + def exp(f: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Exp, List(f))) def exp[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Exp, List(v))) case object Max extends FunctionName def max(f1: Float32, f2: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Max, List(f1, f2))) def max(f1: Float32, f2: Float32, fx: Float32*)(using Source): Float32 = fx.foldLeft(max(f1, f2))((a, b) => max(a, b)) + def max(f1: Float16, f2: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Max, List(f1, f2))) def max[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Max, List(v1, v2))) def max[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V, vx: V*)(using Source): V = @@ -75,6 +88,7 @@ object Functions: case object Min extends FunctionName def min(f1: Float32, f2: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Min, List(f1, f2))) def min(f1: Float32, f2: Float32, fx: Float32*)(using Source): Float32 = fx.foldLeft(min(f1, f2))((a, b) => min(a, b)) + def min(f1: Float16, f2: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Min, List(f1, f2))) def min[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Min, List(v1, v2))) def min[V <: Vec[Float32]: {Tag, FromExpr}](v1: V, v2: V, vx: V*)(using Source): V = @@ -83,6 +97,7 @@ object Functions: // todo add F/U/S to all functions that need it case object Abs extends FunctionName def abs(f: Float32)(using Source): Float32 = Float32(ExtFunctionCall(Abs, List(f))) + def abs(f: Float16)(using Source): Float16 = Float16(ExtFunctionCall(Abs, List(f))) def abs[V <: Vec[Float32]: {Tag, FromExpr}](v: V)(using Source): V = summon[FromExpr[V]].fromExpr(ExtFunctionCall(Abs, List(v))) diff --git a/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/WorkgroupPrimitivesE2eTest.scala b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/WorkgroupPrimitivesE2eTest.scala new file mode 100644 index 00000000..282682a0 --- /dev/null +++ b/cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/dsl/WorkgroupPrimitivesE2eTest.scala @@ -0,0 +1,280 @@ +package io.computenode.cyfra.e2e.dsl + +import io.computenode.cyfra.core.{GBufferRegion, GProgram} +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.binding.{GBuffer, GShared} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.GStruct +import io.computenode.cyfra.runtime.VkCyfraRuntime + +class WorkgroupPrimitivesE2eTest extends munit.FunSuite: + + case class TestLayout(output: GBuffer[Int32]) derives Layout + + test("localInvocationIndex returns correct values"): + VkCyfraRuntime.using: + val size = 512 + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + val localIdx = GIO.localInvocationIndex + GIO.when(idx < size): + GIO.write(layout.output, idx, localIdx) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + val expected = (0 until size).map(_ % 
256).toArray + assert(resultBuf.toSeq == expected.toSeq, s"Local invocation indices mismatch") + + test("workgroupId.x returns correct values"): + VkCyfraRuntime.using: + val size = 512 + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + val wgId = GIO.workgroupId.x + GIO.when(idx < size): + GIO.write(layout.output, idx, wgId) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + val expected = (0 until size).map(_ / 256).toArray + assert(resultBuf.toSeq == expected.toSeq, s"Workgroup IDs mismatch") + + test("barrier compiles and executes without error"): + VkCyfraRuntime.using: + val size = 256 + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + GIO.write(layout.output, idx, idx) + .flatMap(_ => GIO.barrier) + .flatMap(_ => GIO.pure(layout.output.read(idx))) + .flatMap(value => GIO.write(layout.output, idx, value + 1)) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + val expected = (0 until size).map(_ + 1).toArray + assert(resultBuf.toSeq == expected.toSeq, s"Barrier test: expected values incremented by 1") + + test("subgroupSize returns a valid value"): + VkCyfraRuntime.using: + val size = 256 + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + val sgSize = GIO.subgroupSize + GIO.when(idx < size): + GIO.write(layout.output, idx, sgSize) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + assert(resultBuf.forall(_ > 0), s"Subgroup size should be positive") + assert(resultBuf.forall(_ <= 128), s"Subgroup size should be <= 128") + val uniqueValues = resultBuf.distinct + assert(uniqueValues.length == 1, s"All invocations should report the same subgroup size") + + test("shared memory allows workgroup communication".ignore): + VkCyfraRuntime.using: + val workgroupSize = 256 + val shared = GShared[Int32](workgroupSize) + + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](workgroupSize)), + dispatchSize = _ => workgroupSize, + ): layout => + val localIdx = GIO.localInvocationIndex + val globalIdx = GIO.invocationId + shared.write(localIdx, globalIdx) + .flatMap(_ => GIO.barrier) + .flatMap: _ => + val reversedIdx: Int32 = (workgroupSize - 1: Int32) - localIdx + val valueFromReversed = shared.read(reversedIdx) + layout.output.write(globalIdx, valueFromReversed) + + val resultBuf = new Array[Int](workgroupSize) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](workgroupSize)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + val 
expected = (0 until workgroupSize).map(i => workgroupSize - 1 - i).toArray + assert(resultBuf.toSeq == expected.toSeq, s"Shared memory communication failed") + + test("subgroupAdd reduces values within subgroup"): + VkCyfraRuntime.using: + val size = 256 + + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + val sum = GIO.subgroupAdd(1: Int32) + GIO.when(idx < size): + GIO.write(layout.output, idx, sum) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + val subgroupSizeActual = resultBuf.head + assert(subgroupSizeActual > 0, s"Subgroup sum should be positive") + assert(resultBuf.forall(_ == subgroupSizeActual), s"All lanes should have same subgroup sum (subgroup size)") + + test("subgroupInclusiveAdd computes prefix sums"): + VkCyfraRuntime.using: + val size = 256 + + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + val prefixSum = GIO.subgroupInclusiveAdd(1: Int32) + GIO.when(idx < size): + GIO.write(layout.output, idx, prefixSum) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + val subgroupSize = resultBuf.sliding(2).find { case Array(a, b) => b < a }.map(_(0)).getOrElse(resultBuf.last) + assert(subgroupSize > 0, s"Should detect subgroup size from prefix sums") + + test("subgroupBroadcast broadcasts value from specified lane"): + VkCyfraRuntime.using: + val size = 256 + + val program = GProgram.static[Unit, TestLayout]( + layout = _ => TestLayout(GBuffer[Int32](size)), + dispatchSize = _ => size, + ): layout => + val idx = GIO.invocationId + val subgroupLaneId = GIO.subgroupLocalInvocationId + val broadcasted = GIO.subgroupBroadcast(subgroupLaneId, 0: Int32) + GIO.when(idx < size): + GIO.write(layout.output, idx, broadcasted) + + val resultBuf = new Array[Int](size) + val region = GBufferRegion + .allocate[TestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = TestLayout(output = GBuffer[Int32](size)), + onDone = layout => layout.output.readArray(resultBuf), + ) + + assert(resultBuf.forall(_ == 0), s"All lanes should have received broadcast value 0 from lane 0") + + case class FoldTestLayout(input: GBuffer[Float32], output: GBuffer[Float32]) derives Layout + + test("GSeq.fold with subgroupAdd works together"): + VkCyfraRuntime.using: + val size = 256 + val iterations = 4 + + val program = GProgram.static[Unit, FoldTestLayout]( + layout = _ => FoldTestLayout(GBuffer[Float32](size), GBuffer[Float32](size)), + dispatchSize = _ => size, + ): layout => + import io.computenode.cyfra.dsl.collections.GSeq + val idx = GIO.invocationId + val laneId = GIO.subgroupLocalInvocationId + val warpSize = GIO.subgroupSize + + // Each lane computes a partial sum using fold + val partialSum: Float32 = GSeq + .gen[Int32](laneId, _ + warpSize) + .limit(iterations) + .fold(0.0f, (sum: Float32, i: Int32) => { + when(i < size)(sum + GIO.read[Float32](layout.input, i)).otherwise(sum) + }) + + // Then reduce across 
subgroup + val totalSum: Float32 = GIO.subgroupAdd(partialSum) + + GIO.when(idx < size): + GIO.write(layout.output, idx, totalSum) + + import java.nio.{ByteBuffer, ByteOrder} + val inputBuf = ByteBuffer.allocateDirect(size * 4).order(ByteOrder.nativeOrder()) + inputBuf.asFloatBuffer().put(Array.fill(size)(1.0f)) + inputBuf.rewind() + val resultBuf = new Array[Float](size) + + val region = GBufferRegion + .allocate[FoldTestLayout] + .map(l => program.execute((), l)) + + region.runUnsafe( + init = FoldTestLayout( + input = GBuffer[Float32](inputBuf), + output = GBuffer[Float32](size), + ), + onDone = layout => layout.output.readArray(resultBuf), + ) + + // Each invocation should have the sum of the elements it processed + reduced across subgroup + // With iterations=4 and warpSize=32, each lane processes ~4 elements worth of indices + // But with bounds check, only valid indices contribute + assert(resultBuf.forall(_ > 0), s"Total sum should be positive, got ${resultBuf.take(10).mkString(", ")}") diff --git a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/examples/GFunctionExamples.scala b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/examples/GFunctionExamples.scala index 40430035..9e44d964 100644 --- a/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/examples/GFunctionExamples.scala +++ b/cyfra-examples/src/main/scala/io/computenode/cyfra/samples/examples/GFunctionExamples.scala @@ -163,6 +163,36 @@ object GFunctionExamples: println(s"Saved to examples_output/julia.png") println() + def example4_FibonacciSequence(): Unit = + // Test the Fibonacci-like GSeq from documentation using Vec2[Float32] + // Pattern: GSeq.gen[Vec2[Float32]](init, pair => vec2(pair.y, pair.x + pair.y)) + // Generates: (0,1), (1,1), (1,2), (2,3), (3,5), (5,8), ... + // fib(0), fib(1), fib(2), fib(3), fib(4), fib(5), ... 
+ val fibonacciNth: GFunction[GStruct.Empty, Float32, Float32] = GFunction: _ => + // Generate Fibonacci-like pairs: (a, b) -> (b, a+b) + val fibonacci = GSeq.gen[Vec2[Float32]]((0.0f, 1.0f), pair => (pair.y, pair.x + pair.y)) + // limit(n) gives n pairs, last.x = fib(n-1) + // So limit(11).last.x = fib(10) = 55 + fibonacci.limit(11).lastOr(vec2(0.0f, 0.0f)).x + + val input = Array.fill(256)(0.0f) // dummy input + + println("Example 4: Fibonacci Sequence (GSeq.gen with Vec2)") + println("Testing: GSeq.gen[Vec2[Float32]](vec2(0, 1), pair => vec2(pair.y, pair.x + pair.y))") + println("Computing fib(10) on GPU using limit(11).last.x ...") + + val results: Array[Float] = fibonacciNth.run(input) + + // Sequence with limit(11): (0,1), (1,1), (1,2), (2,3), (3,5), (5,8), (8,13), (13,21), (21,34), (34,55), (55,89) + // last.x = 55 = fib(10) + val expected = 55.0f + println(s"Result: fib(10) = ${results(0).toInt}") + println(s"Expected: ${expected.toInt}") + + val correct = Math.abs(results(0) - expected) < 0.001f + println(s"Result correct: $correct") + println() + case class TransformConfig(scale: Float32, offset: Float32) extends GStruct[TransformConfig] def example8_Uniforms(): Unit = @@ -203,6 +233,7 @@ object GFunctionExamples: example1_HelloGpu() example2_VectorOperations() example3_CustomStructs() + example4_FibonacciSequence() example6_Mandelbrot() example7_JuliaSet() example8_Uniforms() diff --git a/cyfra-llama/compare_incremental.py b/cyfra-llama/compare_incremental.py new file mode 100644 index 00000000..a4848e86 --- /dev/null +++ b/cyfra-llama/compare_incremental.py @@ -0,0 +1,51 @@ +"""Compare llama.cpp predictions for incremental generation.""" +from llama_cpp import Llama +import numpy as np + +# Load model +print("Loading model...") +llm = Llama( + model_path="cyfra-llama/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", + n_ctx=64, + verbose=False, + logits_all=True, # Enable all logits +) + +def get_top_predictions(llm, tokens, pos): + """Get top predictions for a specific position.""" + # Reset and evaluate + llm.reset() + llm.eval(tokens) + + # Get logits for the requested position + logits = np.array(llm._scores[pos]) + + top_indices = np.argsort(logits)[::-1][:10] + print(f"Top-10 predictions for position {pos}:") + for idx in top_indices: + try: + token_str = llm.detokenize([int(idx)]).decode('utf-8', errors='replace') + token_str = token_str.encode('ascii', errors='replace').decode('ascii') + except: + token_str = f"[{idx}]" + print(f" Token {idx:5d} ({token_str:>10s}): logit={logits[idx]:10.4f}") + + print(f" Stats: min={logits.min():.4f}, max={logits.max():.4f}, mean={logits.mean():.4f}") + return logits + +# Test 1: [BOS, Hello] -> predict next +tokens_1 = [1, 15043] # BOS + Hello +print(f"\n=== Sequence 1: {tokens_1} (BOS + Hello) ===") +logits_1 = get_top_predictions(llm, tokens_1, 1) + +# Test 2: [BOS, Hello, ,] -> predict next +tokens_2 = [1, 15043, 29892] # BOS + Hello + , +print(f"\n=== Sequence 2: {tokens_2} (BOS + Hello + ,) ===") +logits_2 = get_top_predictions(llm, tokens_2, 2) + +# Also compare logits for position 1 in both sequences (should be the same!) 
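+# With causal attention, position i only attends to tokens 0..i, so appending a
+# token must leave earlier positions' logits unchanged; any drift here points at a
+# KV-cache or masking bug.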
+print(f"\n=== Position 1 logits in sequence 2 (should match sequence 1) ===") +llm.reset() +llm.eval(tokens_2) +logits_2_pos1 = np.array(llm._scores[1]) +print(f" max diff between seq1 pos1 and seq2 pos1: {np.abs(logits_1 - logits_2_pos1).max():.6f}") diff --git a/cyfra-llama/compare_logits.py b/cyfra-llama/compare_logits.py new file mode 100644 index 00000000..8613bcec --- /dev/null +++ b/cyfra-llama/compare_logits.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +"""Compare logits from llama-cpp-python with our implementation.""" + +import numpy as np +from llama_cpp import Llama + +def main(): + model_path = "cyfra-llama/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf" + + print(f"Loading model from {model_path}...") + llm = Llama( + model_path=model_path, + n_ctx=32, + n_batch=32, + verbose=True, + logits_all=True, # Get logits for all tokens + ) + + # Test tokens: BOS (1) + "Hello" token + prompt = "Hello" + print(f"\nPrompt: '{prompt}'") + + # Tokenize + tokens = llm.tokenize(prompt.encode(), add_bos=True) + print(f"Tokens: {tokens}") + + # Evaluate and get logits + llm.reset() + llm.eval(tokens) + + # Get logits for the last token + logits = llm.scores[len(tokens) - 1] + logits_array = np.array(logits, dtype=np.float32) + + print(f"\nLogits shape: {logits_array.shape}") + print(f"Logits stats: min={logits_array.min():.4f}, max={logits_array.max():.4f}, mean={logits_array.mean():.4f}, std={logits_array.std():.4f}") + print(f"Logits sum: {logits_array.sum():.4f}") + + # Get top 5 predictions + top_indices = np.argsort(logits_array)[-5:][::-1] + print("\nTop 5 predictions:") + for idx in top_indices: + token_str = llm.detokenize([idx]).decode('utf-8', errors='replace') + print(f" {idx}: '{token_str}' (score={logits_array[idx]:.2f})") + + # Print first and last few logits for comparison + print(f"\nFirst 10 logits: {logits_array[:10]}") + print(f"Last 10 logits: {logits_array[-10:]}") + + # Also test with just "Hello" (no BOS) + print("\n" + "="*60) + print("Testing single token (15043 = 'Hello')...") + + llm.reset() + llm.eval([15043]) # Just the "Hello" token + + logits2 = llm.scores[0] + logits2_array = np.array(logits2, dtype=np.float32) + + print(f"Logits stats: min={logits2_array.min():.4f}, max={logits2_array.max():.4f}, mean={logits2_array.mean():.4f}, std={logits2_array.std():.4f}") + + # Get top 5 + top_indices2 = np.argsort(logits2_array)[-5:][::-1] + print("\nTop 5 predictions:") + for idx in top_indices2: + token_str = llm.detokenize([idx]).decode('utf-8', errors='replace') + print(f" {idx}: '{token_str}' (score={logits2_array[idx]:.2f})") + +if __name__ == "__main__": + main() diff --git a/cyfra-llama/compare_with_llama_cpp.py b/cyfra-llama/compare_with_llama_cpp.py new file mode 100644 index 00000000..273633e7 --- /dev/null +++ b/cyfra-llama/compare_with_llama_cpp.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +""" +Compare GPU logits against llama.cpp reference. +Run this after running LayerByLayerDebugTest to see the actual llama.cpp output. 
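+
+Assumes llama-cpp-python is installed (pip install llama-cpp-python) and the
+TinyLlama GGUF is present at MODEL_PATH below.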
+""" + +from llama_cpp import Llama +import numpy as np + +MODEL_PATH = "cyfra-llama/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf" + +def main(): + print("Loading model via llama.cpp...") + llm = Llama( + model_path=MODEL_PATH, + n_ctx=512, + n_batch=512, + verbose=False, + logits_all=True, # Get logits for all positions + ) + + # Test token: 15043 = "Hello" + # We want to get the logits for predicting what comes after "Hello" + tokens = [1, 15043] # BOS + "Hello" + + print(f"\nTokens: {tokens}") + print("Running llama.cpp forward pass...") + + # Run forward pass + llm.reset() + llm.eval(tokens) + + # Get logits for last position (predicting what comes after "Hello") + logits = np.array(llm.scores[len(tokens) - 1]) + + print(f"\n=== llama.cpp Reference Logits for token 15043 (Hello) ===") + print(f"Logits shape: {logits.shape}") + print(f"min={logits.min():.4f}, max={logits.max():.4f}, mean={logits.mean():.4f}, std={logits.std():.4f}") + + # Top-10 tokens + top_indices = np.argsort(logits)[::-1][:10] + print("\nTop-10 predicted tokens:") + for idx in top_indices: + token_str = llm.detokenize([idx]).decode('utf-8', errors='replace') + print(f" Token {idx:5d} ({token_str:>10s}): logit={logits[idx]:10.4f}") + + # Argmax + predicted = np.argmax(logits) + predicted_str = llm.detokenize([predicted]).decode('utf-8', errors='replace') + print(f"\nPredicted next token: {predicted} ({predicted_str})") + + # Also test T=2 case: "Hello," + print("\n" + "="*60) + print("Testing T=2: [BOS, Hello, ,]") + tokens2 = [1, 15043, 29892] # BOS + "Hello" + "," + + llm.reset() + llm.eval(tokens2) + + logits2 = np.array(llm.scores[len(tokens2) - 1]) + print(f"\n=== llama.cpp Reference Logits for 'Hello,' (predicting 3rd token) ===") + print(f"min={logits2.min():.4f}, max={logits2.max():.4f}, mean={logits2.mean():.4f}, std={logits2.std():.4f}") + + top_indices2 = np.argsort(logits2)[::-1][:10] + print("\nTop-10 predicted tokens:") + for idx in top_indices2: + token_str = llm.detokenize([idx]).decode('utf-8', errors='replace') + print(f" Token {idx:5d} ({token_str:>10s}): logit={logits2[idx]:10.4f}") + +if __name__ == "__main__": + main() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/Runner.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/Runner.scala new file mode 100644 index 00000000..eb6bc1dc --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/Runner.scala @@ -0,0 +1,250 @@ +package io.computenode.cyfra.llama + +import io.computenode.cyfra.llama.inference.LlamaInference +import io.computenode.cyfra.llama.model.LlamaModel +import io.computenode.cyfra.llama.pipeline.LlamaPipeline +import io.computenode.cyfra.llama.tokenizer.LlamaTokenizer +import io.computenode.cyfra.llama.util.Logger +import io.computenode.cyfra.runtime.VkCyfraRuntime + +import java.nio.file.{Files, Paths} +import scala.io.StdIn + +/** Llama model runner with F16 and F32 pipeline support. 
+  *
+  * Usage:
+  *   runner --model path/to/model.gguf --type f16 --interactive
+  *   runner --model path/to/model.gguf --type f32 --prompt "Hello world"
+  */
+object Runner:
+
+  case class Config(
+    modelPath: String = "",
+    modelType: String = "auto", // "f16", "f32", or "auto"
+    interactive: Boolean = false,
+    measure: Boolean = false,
+    batch: Boolean = false, // Buffer output, print at end
+    prompt: Option[String] = None,
+    maxTokens: Int = 500,
+    temperature: Float = 0.7f,
+    topP: Float = 0.9f,
+    warmupRuns: Int = 3,
+    benchmarkRuns: Int = 5,
+  )
+
+  def main(args: Array[String]): Unit =
+    val config = parseArgs(args)
+
+    if config.modelPath.isEmpty then
+      printUsage()
+      return
+
+    if !Files.exists(Paths.get(config.modelPath)) then
+      System.err.println(s"Error: Model not found: ${config.modelPath}")
+      return
+
+    val resolvedType = if config.modelType == "auto" then
+      if config.modelPath.toLowerCase.contains("f16") then "f16" else "f32"
+    else config.modelType
+
+    println("Cyfra Llama Runner")
+    println(s"Model: ${config.modelPath}")
+    println(s"Type: $resolvedType")
+
+    VkCyfraRuntime.using:
+      val model = LlamaModel.fromGGUF(Paths.get(config.modelPath))
+      val tokenizer = LlamaTokenizer(model.gguf)
+      val inference = new LlamaInference(model, maxT = 1024)
+
+      val pipeline: LlamaPipeline = resolvedType match
+        case "f16" => inference.getF16Pipeline
+        case _ =>
+          System.err.println(s"Unsupported model type: $resolvedType (only f16 is currently wired into this runner)")
+          return
+
+      println(s"Ready: ${model.config.hiddenSize}d, ${model.config.numHiddenLayers}L\n")
+
+      if config.measure then
+        runBenchmark(pipeline, tokenizer, config)
+      else if config.interactive then
+        runInteractive(pipeline, tokenizer, config)
+      else if config.prompt.isDefined then
+        runOnce(pipeline, tokenizer, config.prompt.get, config)
+      else
+        printUsage()
+
+  private def runInteractive(pipeline: LlamaPipeline, tokenizer: LlamaTokenizer, config: Config): Unit =
+    println("Interactive mode. 
Commands: quit, exit") + println("-" * 40) + + var running = true + while running do + print("\nYou: ") + System.out.flush() + val userInput = StdIn.readLine() + + if userInput == null || userInput.trim.toLowerCase == "quit" || userInput.trim.toLowerCase == "exit" then + running = false + else if userInput.trim.nonEmpty then + val prompt = s"<|user|>\n${userInput.trim}\n<|assistant|>\n" + runGeneration(pipeline, tokenizer, prompt, config) + + private def runOnce(pipeline: LlamaPipeline, tokenizer: LlamaTokenizer, prompt: String, config: Config): Unit = + println(s"Prompt: $prompt\n") + runGeneration(pipeline, tokenizer, prompt, config) + + private def runBenchmark(pipeline: LlamaPipeline, tokenizer: LlamaTokenizer, config: Config): Unit = + val prompt = config.prompt.getOrElse("Once upon a time") + val tokens = tokenizer.encode(prompt) + + println(s"Benchmark: '$prompt' -> ${config.maxTokens} tokens") + println(s"Warmup: ${config.warmupRuns} runs, Benchmark: ${config.benchmarkRuns} runs") + println(f"Sampling: temperature=${config.temperature}%.2f, top_p=${config.topP}%.2f\n") + + // Warmup + print("Warming up: ") + for i <- 1 to config.warmupRuns do + pipeline.generate( + tokens, config.maxTokens, + temperature = config.temperature, + topP = config.topP, + stopTokens = Set(tokenizer.eosToken), + ) + print(s"$i ") + System.out.flush() + println("done\n") + + // Benchmark runs + println("Benchmark runs:") + + val (decoded, stats, lastGenerated) = (1 to config.benchmarkRuns).map: i => + val generated = pipeline.generate( + tokens, config.maxTokens, + temperature = config.temperature, + topP = config.topP, + stopTokens = Set(tokenizer.eosToken), + ) + val decoded = tokenizer.decode(generated) + val s = pipeline.lastStats.get + println(f" Run $i: ${s.generatedTokens} tokens, generate ${s.decodeTokPerSec}%.1f tok/s") + (decoded, s, generated) + .unzip3 + + val avgDecode = stats.map(_.decodeTokPerSec).sum / stats.length + val bestDecode = stats.map(_.decodeTokPerSec).max + + println() + println("Last generation:") + println(decoded.last.toString) + println(f"Average: $avgDecode%.1f tok/s") + println(f"Best: $bestDecode%.1f tok/s") + + private def runGeneration(pipeline: LlamaPipeline, tokenizer: LlamaTokenizer, prompt: String, config: Config): Unit = + val tokens = tokenizer.encode(prompt) + + if config.batch then + // Batch mode: print at end + val generated = pipeline.generate( + promptTokens = tokens, + maxNewTokens = config.maxTokens, + temperature = config.temperature, + topP = config.topP, + stopTokens = Set(tokenizer.eosToken), + ) + val decoded = tokenizer.decode(generated) + println(s"Output: $decoded") + else + // Streaming mode: print tokens as they arrive + print("Output: ") + System.out.flush() + val generated = pipeline.generate( + promptTokens = tokens, + maxNewTokens = config.maxTokens, + temperature = config.temperature, + topP = config.topP, + onToken = token => + val text = tokenizer.decodeToken(token) + if !text.contains("") && !text.contains("<|") then + print(text) + System.out.flush() + , + stopTokens = Set(tokenizer.eosToken), + ) + println() + pipeline.lastStats match + case Some(stats) => + println(f"[${generated.length} tokens, generate ${stats.decodeTokPerSec}%.1f tok/s]") + case None => + println(f"[${generated.length} tokens]") + + private def parseArgs(args: Array[String]): Config = + var config = Config() + var i = 0 + while i < args.length do + args(i) match + case "--model" | "-m" if i + 1 < args.length => + config = config.copy(modelPath = args(i + 1)) + i += 2 
+ case "--type" | "-t" if i + 1 < args.length => + config = config.copy(modelType = args(i + 1).toLowerCase) + i += 2 + case "--interactive" | "-i" => + config = config.copy(interactive = true) + i += 1 + case "--measure" => + config = config.copy(measure = true) + i += 1 + case "--batch" | "-b" => + config = config.copy(batch = true) + i += 1 + case "--warmup" if i + 1 < args.length => + config = config.copy(warmupRuns = args(i + 1).toInt) + i += 2 + case "--runs" if i + 1 < args.length => + config = config.copy(benchmarkRuns = args(i + 1).toInt) + i += 2 + case "--prompt" | "-p" if i + 1 < args.length => + config = config.copy(prompt = Some(args(i + 1))) + i += 2 + case "--max-tokens" | "-n" if i + 1 < args.length => + config = config.copy(maxTokens = args(i + 1).toInt) + i += 2 + case "--temperature" if i + 1 < args.length => + config = config.copy(temperature = args(i + 1).toFloat) + i += 2 + case "--top-p" if i + 1 < args.length => + config = config.copy(topP = args(i + 1).toFloat) + i += 2 + case arg if !arg.startsWith("-") && config.modelPath.isEmpty => + config = config.copy(modelPath = arg) + i += 1 + case other => + System.err.println(s"Unknown argument: $other") + i += 1 + config + + private def printUsage(): Unit = + println(""" + |Usage: runner [OPTIONS] [MODEL_PATH] + | + |Modes: + | -i, --interactive Interactive chat mode (streaming output) + | -p, --prompt TEXT Single prompt and exit (streaming output) + | --measure Benchmark mode (no output, multiple runs) + | + |Options: + | -m, --model PATH Path to GGUF model file + | -t, --type TYPE Model type: f16, f32, or auto (default: auto) + | -b, --batch Buffer output, print at end (faster) + | -n, --max-tokens N Maximum tokens to generate (default: 500) + | --temperature FLOAT Sampling temperature (default: 0.7) + | --top-p FLOAT Top-p sampling threshold (default: 0.9) + | --warmup N Warmup runs for benchmark (default: 3) + | --runs N Benchmark runs (default: 5) + | + |Examples: + | runner -m model.gguf -t f16 -i + | runner -m model.gguf -p "Hello world" -b -n 100 + | runner -m model.gguf --measure -n 128 + |""".stripMargin) + diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/Dequantize.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/Dequantize.scala new file mode 100644 index 00000000..ff3908d0 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/Dequantize.scala @@ -0,0 +1,204 @@ +package io.computenode.cyfra.llama.gguf + +import java.nio.{ByteBuffer, ByteOrder} + +/** Dequantization functions for GGUF quantized tensors. + * + * Based on llama.cpp's ggml-quants.c + */ +object Dequantize: + + val QK_K = 256 // Block size for K-quants + + /** Convert half-precision float16 to float32. 
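+    * For example, bits 0x3C00 decode to 1.0f, and 0xC000 decodes to -2.0f.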
+ * + * IEEE 754 half-precision: 1 sign bit, 5 exponent bits, 10 mantissa bits + */ + def fp16ToFp32(h: Short): Float = + val sign = (h >> 15) & 1 + val exp = (h >> 10) & 0x1F + val mant = h & 0x3FF + + if exp == 0 then + // Denormalized or zero + if mant == 0 then + if sign == 1 then -0.0f else 0.0f + else + // Denormalized number + val f = mant.toFloat / 1024.0f + val result = f * math.pow(2, -14).toFloat + if sign == 1 then -result else result + else if exp == 31 then + // Infinity or NaN + if mant == 0 then + if sign == 1 then Float.NegativeInfinity else Float.PositiveInfinity + else + Float.NaN + else + // Normalized number + val f = 1.0f + mant.toFloat / 1024.0f + val result = f * math.pow(2, exp - 15).toFloat + if sign == 1 then -result else result + + /** Dequantize Q4_K block to float32. + * + * Q4_K format: + * - 256 values per block + * - 2x float16 for d and dmin + * - 12 bytes for scales (6-bit each, packed) + * - 128 bytes for quantized values (4-bit each, packed) + * - Total: 144 bytes per block + */ + def dequantizeQ4K(data: Array[Byte], numElements: Long): Array[Float] = + val numBlocks = (numElements / QK_K).toInt + val result = new Array[Float](numElements.toInt) + val buf = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN) + + var resultIdx = 0 + for blockIdx <- 0 until numBlocks do + val blockStart = blockIdx * 144 + + // Read d and dmin (fp16) + val dHalf = buf.getShort(blockStart) + val dminHalf = buf.getShort(blockStart + 2) + val d = fp16ToFp32(dHalf) + val dmin = fp16ToFp32(dminHalf) + + // Read scales (12 bytes, 6-bit values packed) + val scales = new Array[Byte](12) + for i <- 0 until 12 do + scales(i) = buf.get(blockStart + 4 + i) + + // Read quantized values (128 bytes, 4-bit packed) + val qs = new Array[Byte](128) + for i <- 0 until 128 do + qs(i) = buf.get(blockStart + 16 + i) + + // Dequantize 256 values in groups of 64 + var is = 0 + var qsIdx = 0 + for j <- 0 until 4 do // 4 groups of 64 + // Get scale and min for this group (two sub-groups of 32) + val (sc1, m1) = getScaleMinK4(is, scales) + val (sc2, m2) = getScaleMinK4(is + 1, scales) + + val d1 = d * sc1 + val m1Val = dmin * m1 + val d2 = d * sc2 + val m2Val = dmin * m2 + + // First 32 values (low nibble) + for l <- 0 until 32 do + val q = qs(qsIdx + l) & 0x0F + result(resultIdx) = d1 * q - m1Val + resultIdx += 1 + + // Second 32 values (high nibble) + for l <- 0 until 32 do + val q = (qs(qsIdx + l) >> 4) & 0x0F + result(resultIdx) = d2 * q - m2Val + resultIdx += 1 + + qsIdx += 32 + is += 2 + + result + + /** Get scale and min from packed 6-bit values in Q4_K scales array. + * + * Matches llama.cpp's get_scale_min_k4 implementation exactly. + * scales array is 12 bytes, j ranges 0-7. + * + * IMPORTANT: Use & 0xFF to convert signed bytes to unsigned before shifting, + * otherwise Java's signed byte extension causes incorrect results when bit 7 is set. + */ + private def getScaleMinK4(j: Int, scales: Array[Byte]): (Float, Float) = + if j < 4 then + // Simple 6-bit extraction from lower bytes + val d = (scales(j) & 0x3F).toFloat + val m = (scales(j + 4) & 0x3F).toFloat + (d, m) + else + // Combine bits from different positions - use & 0xFF for unsigned interpretation + val sj4 = scales(j + 4) & 0xFF // scales[j+4] as unsigned + val sjm4 = scales(j - 4) & 0xFF // scales[j-4] as unsigned + val sj = scales(j) & 0xFF // scales[j] as unsigned + val d = ((sj4 & 0x0F) | ((sjm4 >> 6) << 4)).toFloat + val m = (((sj4 >> 4) & 0x0F) | ((sj >> 6) << 4)).toFloat + (d, m) + + /** Dequantize Q6_K block to float32. 
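+    * Each value is reconstructed as d * scale * (q - 32), where the 6-bit quant q is
+    * assembled from 4 low bits (ql) and 2 high bits (qh).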
+ * + * Q6_K format (matches llama.cpp exactly): + * - 256 values per block + * - 128 bytes for low 4 bits (ql) + * - 64 bytes for high 2 bits (qh) + * - 16 bytes for scales (int8) + * - 2 bytes for d (fp16) + * - Total: 210 bytes per block + */ + def dequantizeQ6K(data: Array[Byte], numElements: Long): Array[Float] = + val numBlocks = (numElements / QK_K).toInt + val result = new Array[Float](numElements.toInt) + val buf = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN) + + for blockIdx <- 0 until numBlocks do + val blockStart = blockIdx * 210 + val blockResultStart = blockIdx * QK_K + + // Read d (fp16) at offset 208 + val d = fp16ToFp32(buf.getShort(blockStart + 208)) + + // Two halves: n=0 (values 0-127), n=1 (values 128-255) + var qlOffset = 0 + var qhOffset = 0 + var scOffset = 0 + var yOffset = 0 + + for n <- 0 until 2 do + // Process 128 values in this half + for l <- 0 until 32 do + val is = l / 16 // Scale index within this 128-value block + + // Read ql values + val ql0 = buf.get(blockStart + qlOffset + l) & 0xFF + val ql32 = buf.get(blockStart + qlOffset + l + 32) & 0xFF + + // Read qh value + val qhVal = buf.get(blockStart + 128 + qhOffset + l) & 0xFF + + // Read scales (int8, so need sign extension) + val sc0 = buf.get(blockStart + 192 + scOffset + is + 0).toInt + val sc2 = buf.get(blockStart + 192 + scOffset + is + 2).toInt + val sc4 = buf.get(blockStart + 192 + scOffset + is + 4).toInt + val sc6 = buf.get(blockStart + 192 + scOffset + is + 6).toInt + + // Compute 4 quantized values + val q1 = ((ql0 & 0x0F) | (((qhVal >> 0) & 3) << 4)) - 32 + val q2 = ((ql32 & 0x0F) | (((qhVal >> 2) & 3) << 4)) - 32 + val q3 = ((ql0 >> 4) | (((qhVal >> 4) & 3) << 4)) - 32 + val q4 = ((ql32 >> 4) | (((qhVal >> 6) & 3) << 4)) - 32 + + // Store 4 dequantized values + result(blockResultStart + yOffset + l + 0) = d * sc0 * q1 + result(blockResultStart + yOffset + l + 32) = d * sc2 * q2 + result(blockResultStart + yOffset + l + 64) = d * sc4 * q3 + result(blockResultStart + yOffset + l + 96) = d * sc6 * q4 + + // Move to next half + qlOffset += 64 + qhOffset += 32 + scOffset += 8 + yOffset += 128 + + result + + /** Dequantize F16 to F32. */ + def dequantizeF16(data: Array[Byte], numElements: Long): Array[Float] = + val result = new Array[Float](numElements.toInt) + val buf = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN) + + for i <- 0 until numElements.toInt do + result(i) = fp16ToFp32(buf.getShort(i * 2)) + + result diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/GGUFReader.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/GGUFReader.scala new file mode 100644 index 00000000..2358a809 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/gguf/GGUFReader.scala @@ -0,0 +1,477 @@ +package io.computenode.cyfra.llama.gguf + +import java.io.RandomAccessFile +import java.nio.{ByteBuffer, ByteOrder} +import java.nio.channels.FileChannel +import java.nio.file.Path +import scala.collection.mutable + +/** GGUF (GGML Universal File) format reader. + * + * GGUF is llama.cpp's model format. This reader parses the file header, + * metadata key-value pairs, and tensor information. 
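+  *
+  * Typical use (a sketch; metadata keys and tensor names vary by model family):
+  * {{{
+  * val gguf   = GGUFReader.read(Paths.get("model.gguf"))
+  * val layers = gguf.getInt("llama.block_count")
+  * val embd   = gguf.getTensor("token_embd.weight")
+  * gguf.close()
+  * }}}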
+ * + * File structure: + * - Magic: 4 bytes ("GGUF" = 0x46554747) + * - Version: uint32 (currently 3) + * - Tensor count: uint64 + * - KV count: uint64 + * - Key-value pairs (metadata) + * - Tensor info (name, dimensions, type, offset) + * - Padding to alignment (default 32 bytes) + * - Tensor data + */ +object GGUFReader: + val GGUF_MAGIC: Int = 0x46554747 // "GGUF" + val GGUF_VERSION: Int = 3 + val DEFAULT_ALIGNMENT: Int = 32 + + /** Value types in GGUF metadata. */ + enum ValueType(val id: Int): + case UINT8 extends ValueType(0) + case INT8 extends ValueType(1) + case UINT16 extends ValueType(2) + case INT16 extends ValueType(3) + case UINT32 extends ValueType(4) + case INT32 extends ValueType(5) + case FLOAT32 extends ValueType(6) + case BOOL extends ValueType(7) + case STRING extends ValueType(8) + case ARRAY extends ValueType(9) + case UINT64 extends ValueType(10) + case INT64 extends ValueType(11) + case FLOAT64 extends ValueType(12) + + object ValueType: + def fromId(id: Int): ValueType = + ValueType.values.find(_.id == id).getOrElse( + throw new IllegalArgumentException(s"Unknown value type: $id") + ) + + /** Quantization types for tensors. */ + enum QuantType(val id: Int, val blockSize: Int, val bytesPerBlock: Int): + case F32 extends QuantType(0, 1, 4) + case F16 extends QuantType(1, 1, 2) + case Q4_0 extends QuantType(2, 32, 18) + case Q4_1 extends QuantType(3, 32, 20) + case Q5_0 extends QuantType(6, 32, 22) + case Q5_1 extends QuantType(7, 32, 24) + case Q8_0 extends QuantType(8, 32, 34) + case Q8_1 extends QuantType(9, 32, 36) + case Q2_K extends QuantType(10, 256, 84) + case Q3_K extends QuantType(11, 256, 110) + case Q4_K extends QuantType(12, 256, 144) + case Q5_K extends QuantType(13, 256, 176) + case Q6_K extends QuantType(14, 256, 210) + case Q8_K extends QuantType(15, 256, 292) + case IQ2_XXS extends QuantType(16, 256, 66) + case IQ2_XS extends QuantType(17, 256, 74) + case IQ3_XXS extends QuantType(18, 256, 98) + case IQ1_S extends QuantType(19, 256, 50) + case IQ4_NL extends QuantType(20, 32, 18) + case IQ3_S extends QuantType(21, 256, 110) + case IQ2_S extends QuantType(22, 256, 82) + case IQ4_XS extends QuantType(23, 256, 136) + case BF16 extends QuantType(30, 1, 2) + + object QuantType: + def fromId(id: Int): QuantType = + QuantType.values.find(_.id == id).getOrElse( + throw new IllegalArgumentException(s"Unknown quant type: $id") + ) + + /** Metadata value can be various types. */ + sealed trait MetaValue + case class MetaUInt8(value: Byte) extends MetaValue + case class MetaInt8(value: Byte) extends MetaValue + case class MetaUInt16(value: Short) extends MetaValue + case class MetaInt16(value: Short) extends MetaValue + case class MetaUInt32(value: Int) extends MetaValue + case class MetaInt32(value: Int) extends MetaValue + case class MetaFloat32(value: Float) extends MetaValue + case class MetaBool(value: Boolean) extends MetaValue + case class MetaString(value: String) extends MetaValue + case class MetaUInt64(value: Long) extends MetaValue + case class MetaInt64(value: Long) extends MetaValue + case class MetaFloat64(value: Double) extends MetaValue + case class MetaArray(values: Seq[MetaValue]) extends MetaValue + + /** Tensor information from GGUF file. 
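+    * `numBytes` rounds the element count up to whole quantization blocks before
+    * multiplying by the per-block byte size.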
*/ + case class TensorInfo( + name: String, + shape: Array[Long], + quantType: QuantType, + offset: Long, + ): + def numElements: Long = shape.product + def numBytes: Long = + val blocks = (numElements + quantType.blockSize - 1) / quantType.blockSize + blocks * quantType.bytesPerBlock + + /** Parsed GGUF file. */ + case class GGUFFile( + version: Int, + metadata: Map[String, MetaValue], + tensors: Seq[TensorInfo], + dataOffset: Long, + channel: FileChannel, + ): + def close(): Unit = channel.close() + + /** Get metadata value as string. */ + def getString(key: String): Option[String] = metadata.get(key).collect { case MetaString(v) => v } + + /** Get metadata value as int. */ + def getInt(key: String): Option[Int] = metadata.get(key).collect { + case MetaUInt32(v) => v + case MetaInt32(v) => v + case MetaUInt8(v) => v.toInt & 0xFF + case MetaInt8(v) => v.toInt + } + + /** Get metadata value as long. */ + def getLong(key: String): Option[Long] = metadata.get(key).collect { + case MetaUInt64(v) => v + case MetaInt64(v) => v + case MetaUInt32(v) => v.toLong & 0xFFFFFFFFL + case MetaInt32(v) => v.toLong + } + + /** Get metadata value as float. */ + def getFloat(key: String): Option[Float] = metadata.get(key).collect { + case MetaFloat32(v) => v + case MetaFloat64(v) => v.toFloat + } + + /** Get metadata value as string array. */ + def getStringArray(key: String): Option[Array[String]] = metadata.get(key).collect { + case MetaArray(vals) => vals.collect { case MetaString(s) => s }.toArray + } + + /** Get metadata value as float array. */ + def getFloatArray(key: String): Option[Array[Float]] = metadata.get(key).collect { + case MetaArray(vals) => vals.collect { case MetaFloat32(f) => f }.toArray + } + + /** Get tensor by name. */ + def getTensor(name: String): Option[TensorInfo] = tensors.find(_.name == name) + + /** Read tensor data as float array. Only works for F32 tensors. */ + def readTensorF32(tensor: TensorInfo): Array[Float] = + require(tensor.quantType == QuantType.F32, s"Tensor ${tensor.name} is ${tensor.quantType}, not F32") + val buffer = ByteBuffer.allocate(tensor.numBytes.toInt).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buffer, dataOffset + tensor.offset) + buffer.flip() + val result = Array.ofDim[Float](tensor.numElements.toInt) + buffer.asFloatBuffer().get(result) + result + + /** Read tensor data as raw bytes. */ + def readTensorBytes(tensor: TensorInfo): Array[Byte] = + val buffer = ByteBuffer.allocate(tensor.numBytes.toInt) + channel.read(buffer, dataOffset + tensor.offset) + buffer.flip() + val result = Array.ofDim[Byte](tensor.numBytes.toInt) + buffer.get(result) + result + + /** Read tensor data directly into a ByteBuffer for GPU upload. + * Returns little-endian ordered buffer suitable for GBuffer[UInt32]. + */ + def readTensorToBuffer(tensor: TensorInfo): ByteBuffer = + val buffer = ByteBuffer.allocateDirect(tensor.numBytes.toInt).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buffer, dataOffset + tensor.offset) + buffer.rewind() + buffer + + /** Read Q4_K tensor as UInt32 array for GPU upload. + * Q4_K: 144 bytes = 36 UInt32 per 256-element block. 
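+      * Since 144 is divisible by 4, blocks pack into whole uint32 words with no
+      * padding (unlike Q6_K below).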
+      */
+    def readTensorQ4KAsUInt32(tensor: TensorInfo): Array[Int] =
+      require(tensor.quantType == QuantType.Q4_K, s"Tensor ${tensor.name} is ${tensor.quantType}, not Q4_K")
+      val bytes = readTensorBytes(tensor)
+      val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
+      val numUInt32 = tensor.numBytes.toInt / 4
+      val result = Array.ofDim[Int](numUInt32)
+      buf.asIntBuffer().get(result)
+      result
+
+    /** Read any quantized tensor as UInt32 array for GPU upload.
+      * Use this for tensors that may have different quantization types.
+      */
+    def readTensorAsUInt32(tensor: TensorInfo): Array[Int] =
+      val bytes = readTensorBytes(tensor)
+      val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
+      val numUInt32 = tensor.numBytes.toInt / 4
+      val result = Array.ofDim[Int](numUInt32)
+      buf.asIntBuffer().get(result)
+      result
+
+    /** Read Q6_K tensor as UInt32 array for GPU upload.
+      * Q6_K: 210 bytes per 256-element block.
+      *
+      * Since 210 is not divisible by 4, blocks are packed contiguously as raw bytes
+      * and only the total is padded up to a multiple of 4; the GPU reads at byte
+      * granularity within each block.
+      */
+    def readTensorQ6KAsUInt32(tensor: TensorInfo): Array[Int] =
+      require(tensor.quantType == QuantType.Q6_K, s"Tensor ${tensor.name} is ${tensor.quantType}, not Q6_K")
+      val bytes = readTensorBytes(tensor)
+      val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
+
+      // Q6_K blocks are 210 bytes. We pack them as-is, reading at byte level on GPU.
+      // Just pad the total to a multiple of 4 for uint32 alignment.
+      val numUInt32 = (tensor.numBytes.toInt + 3) / 4
+      val result = Array.ofDim[Int](numUInt32)
+
+      // Copy bytes into uint32 array
+      var i = 0
+      while i < tensor.numBytes.toInt / 4 do
+        result(i) = buf.getInt(i * 4)
+        i += 1
+
+      // Handle remaining bytes (if any)
+      if tensor.numBytes.toInt % 4 != 0 then
+        var lastWord = 0
+        var j = 0
+        while j < tensor.numBytes.toInt % 4 do
+          lastWord |= (bytes(tensor.numBytes.toInt - tensor.numBytes.toInt % 4 + j) & 0xFF) << (j * 8)
+          j += 1
+        result(i) = lastWord
+
+      result
+
+    /** Read and dequantize tensor to Float32.
+      *
+      * Supports F32, F16, Q4_K, and Q6_K quantization types.
+      */
+    def readTensorDequantized(tensor: TensorInfo): Array[Float] =
+      val bytes = readTensorBytes(tensor)
+      tensor.quantType match
+        case QuantType.F32 =>
+          val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
+          val result = Array.ofDim[Float](tensor.numElements.toInt)
+          buf.asFloatBuffer().get(result)
+          result
+        case QuantType.F16 =>
+          Dequantize.dequantizeF16(bytes, tensor.numElements)
+        case QuantType.Q4_K =>
+          Dequantize.dequantizeQ4K(bytes, tensor.numElements)
+        case QuantType.Q6_K =>
+          Dequantize.dequantizeQ6K(bytes, tensor.numElements)
+        case other =>
+          throw new UnsupportedOperationException(s"Dequantization not implemented for $other")
+
+    /** Read F16 tensor as raw bytes without conversion.
+      *
+      * Returns the raw F16 bytes (2 bytes per element) for direct GPU upload.
+      * This avoids F32 conversion, saving 2x memory.
+      */
+    def readTensorF16Bytes(tensor: TensorInfo): Array[Byte] =
+      require(tensor.quantType == QuantType.F16, s"Expected F16 tensor, got ${tensor.quantType}")
+      readTensorBytes(tensor)
+
+  /** Read GGUF file from path. 
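+    * Header, metadata, and tensor info are parsed eagerly; tensor data is read on
+    * demand through the returned file channel, so call close() when done.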
*/ + def read(path: Path): GGUFFile = + val raf = new RandomAccessFile(path.toFile, "r") + val channel = raf.getChannel + + // Read header + val headerBuf = ByteBuffer.allocate(24).order(ByteOrder.LITTLE_ENDIAN) + channel.read(headerBuf, 0) + headerBuf.flip() + + val magic = headerBuf.getInt + if magic != GGUF_MAGIC then + throw new IllegalArgumentException(s"Invalid GGUF magic: ${magic.toHexString}, expected ${GGUF_MAGIC.toHexString}") + + val version = headerBuf.getInt + if version != 2 && version != 3 then + throw new IllegalArgumentException(s"Unsupported GGUF version: $version") + + val tensorCount = headerBuf.getLong + val kvCount = headerBuf.getLong + + // Parse key-value pairs + var offset = 24L + val metadata = mutable.Map[String, MetaValue]() + + for _ <- 0L until kvCount do + val (key, value, newOffset) = readKV(channel, offset) + metadata(key) = value + offset = newOffset + + // Parse tensor info + val tensors = mutable.ArrayBuffer[TensorInfo]() + for _ <- 0L until tensorCount do + val (tensor, newOffset) = readTensorInfo(channel, offset) + tensors += tensor + offset = newOffset + + // Compute data offset with alignment + val alignment = metadata.get("general.alignment").collect { case MetaUInt32(v) => v }.getOrElse(DEFAULT_ALIGNMENT) + val padding = offset % alignment + val dataOffset = if padding == 0 then offset else offset + alignment - padding + + GGUFFile(version, metadata.toMap, tensors.toSeq, dataOffset, channel) + + private def readKV(channel: FileChannel, offset: Long): (String, MetaValue, Long) = + var pos = offset + + // Read key (string) + val (key, keyEndPos) = readString(channel, pos) + pos = keyEndPos + + // Read value type + val typeBuf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(typeBuf, pos) + typeBuf.flip() + val valueTypeId = typeBuf.getInt + pos += 4 + + val valueType = ValueType.fromId(valueTypeId) + val (value, valueEndPos) = readValue(channel, pos, valueType) + + (key, value, valueEndPos) + + private def readValue(channel: FileChannel, offset: Long, valueType: ValueType): (MetaValue, Long) = + valueType match + case ValueType.UINT8 => + val buf = ByteBuffer.allocate(1) + channel.read(buf, offset) + (MetaUInt8(buf.get(0)), offset + 1) + + case ValueType.INT8 => + val buf = ByteBuffer.allocate(1) + channel.read(buf, offset) + (MetaInt8(buf.get(0)), offset + 1) + + case ValueType.UINT16 => + val buf = ByteBuffer.allocate(2).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaUInt16(buf.getShort(0)), offset + 2) + + case ValueType.INT16 => + val buf = ByteBuffer.allocate(2).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaInt16(buf.getShort(0)), offset + 2) + + case ValueType.UINT32 => + val buf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaUInt32(buf.getInt(0)), offset + 4) + + case ValueType.INT32 => + val buf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaInt32(buf.getInt(0)), offset + 4) + + case ValueType.FLOAT32 => + val buf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaFloat32(buf.getFloat(0)), offset + 4) + + case ValueType.BOOL => + val buf = ByteBuffer.allocate(1) + channel.read(buf, offset) + (MetaBool(buf.get(0) != 0), offset + 1) + + case ValueType.STRING => + val (str, endPos) = readString(channel, offset) + (MetaString(str), endPos) + + case ValueType.UINT64 => + val buf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) 
+ (MetaUInt64(buf.getLong(0)), offset + 8) + + case ValueType.INT64 => + val buf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaInt64(buf.getLong(0)), offset + 8) + + case ValueType.FLOAT64 => + val buf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + channel.read(buf, offset) + (MetaFloat64(buf.getDouble(0)), offset + 8) + + case ValueType.ARRAY => + var pos = offset + // Read array element type + val typeBuf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(typeBuf, pos) + typeBuf.flip() + val elemTypeId = typeBuf.getInt + pos += 4 + + // Read array length + val lenBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + channel.read(lenBuf, pos) + lenBuf.flip() + val arrayLen = lenBuf.getLong + pos += 8 + + val elemType = ValueType.fromId(elemTypeId) + val values = mutable.ArrayBuffer[MetaValue]() + + for _ <- 0L until arrayLen do + val (value, endPos) = readValue(channel, pos, elemType) + values += value + pos = endPos + + (MetaArray(values.toSeq), pos) + + private def readString(channel: FileChannel, offset: Long): (String, Long) = + // Read string length (uint64) + val lenBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + channel.read(lenBuf, offset) + lenBuf.flip() + val strLen = lenBuf.getLong.toInt + + // Read string bytes + val strBuf = ByteBuffer.allocate(strLen) + channel.read(strBuf, offset + 8) + strBuf.flip() + val bytes = Array.ofDim[Byte](strLen) + strBuf.get(bytes) + + (new String(bytes, "UTF-8"), offset + 8 + strLen) + + private def readTensorInfo(channel: FileChannel, offset: Long): (TensorInfo, Long) = + var pos = offset + + // Read tensor name + val (name, nameEndPos) = readString(channel, pos) + pos = nameEndPos + + // Read number of dimensions + val dimBuf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(dimBuf, pos) + dimBuf.flip() + val nDims = dimBuf.getInt + pos += 4 + + // Read dimensions + val shape = Array.ofDim[Long](nDims) + val shapeBuf = ByteBuffer.allocate(8 * nDims).order(ByteOrder.LITTLE_ENDIAN) + channel.read(shapeBuf, pos) + shapeBuf.flip() + for i <- 0 until nDims do + shape(i) = shapeBuf.getLong + pos += 8 * nDims + + // Read quant type + val qtBuf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN) + channel.read(qtBuf, pos) + qtBuf.flip() + val quantTypeId = qtBuf.getInt + pos += 4 + + // Read tensor data offset + val offsetBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + channel.read(offsetBuf, pos) + offsetBuf.flip() + val tensorOffset = offsetBuf.getLong + pos += 8 + + val quantType = QuantType.fromId(quantTypeId) + (TensorInfo(name, shape, quantType, tensorOffset), pos) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/inference/LlamaInference.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/inference/LlamaInference.scala new file mode 100644 index 00000000..2fda0c20 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/inference/LlamaInference.scala @@ -0,0 +1,133 @@ +package io.computenode.cyfra.llama.inference + +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.llama.gguf.GGUFReader +import io.computenode.cyfra.llama.gguf.GGUFReader.{QuantType, TensorInfo} +import io.computenode.cyfra.llama.model.{LlamaConfig, LlamaModel} +import io.computenode.cyfra.llama.pipeline.LlamaF16Pipeline +import io.computenode.cyfra.llama.util.Logger +import io.computenode.cyfra.runtime.VkCyfraRuntime + +/** Llama inference engine. 
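+  *
+  * A typical setup mirrors the Runner (a sketch; the model path is an assumption):
+  * {{{
+  * VkCyfraRuntime.using:
+  *   val model     = LlamaModel.fromGGUF(Paths.get("models/model-f16.gguf"))
+  *   val inference = new LlamaInference(model, maxT = 1024)
+  *   val pipeline  = inference.getF16Pipeline
+  * }}}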
+ * + * Loads weights from GGUF and runs the forward pass on GPU. + * + * Supports two pipeline modes: + * - LlamaF32Pipeline: For quantized (Q4_K/Q6_K) models + * - LlamaF16Pipeline: For F16-native models (like Llama 3.2) + * + * @param model The loaded Llama model + * @param maxT Maximum sequence length for the pipeline + */ +class LlamaInference(model: LlamaModel, maxT: Int = 1)(using runtime: VkCyfraRuntime): + val config: LlamaConfig = model.config + + private lazy val allWeightsAreF16: Boolean = + val f32Tensors = model.gguf.tensors.filter(_.quantType == QuantType.F32) + val f16Tensors = model.gguf.tensors.filter(_.quantType == QuantType.F16) + Logger.debug(s"Model: ${f16Tensors.length} F16, ${f32Tensors.length} F32 tensors") + val hasF16Weights = f16Tensors.exists(t => t.name.contains("weight")) + val f32WeightMatrices = f32Tensors.filter(t => + t.name.contains("attn_q") || t.name.contains("attn_k") || t.name.contains("attn_v") || + t.name.contains("attn_output") || t.name.contains("ffn_gate") || t.name.contains("ffn_up") || + t.name.contains("ffn_down") || t.name.contains("token_embd") || t.name == "output.weight" + ) + hasF16Weights && f32WeightMatrices.isEmpty + + /** Check if weights are a mix of Q4_K and Q6_K (common in TinyLlama). */ + private lazy val hasMixedQuantization: Boolean = + val hasQ4K = model.gguf.tensors.exists(_.quantType == QuantType.Q4_K) + val hasQ6K = model.gguf.tensors.exists(_.quantType == QuantType.Q6_K) + hasQ4K && hasQ6K + + + // F16-Native Pipeline for F16 models (KV-cached only) + private lazy val f16Weights = if allWeightsAreF16 then Some(loadF16Weights()) else None + + // F16-Native KV Cached Pipeline with Vec4 optimizations (4x weight bandwidth!) + private lazy val f16Pipeline: LlamaF16Pipeline = + require(f16Weights.isDefined, "F16 KV pipeline requires F16 weights.") + LlamaF16Pipeline(f16Weights.get, config, maxT) + + /** Get the F16-native KV cached pipeline for efficient incremental inference. + * + * This pipeline uses KV caching for O(1) per-token inference: + * - Prefill: Process all prompt tokens at once + * - Decode: Process 1 token at a time, attend to full KV cache + * + * Uses Vec4-optimized matmuls for 4x weight memory bandwidth. + * Requires all weights to be F16 quantized. + * Requires dimensions (C, kvSize, FFN) to be divisible by 4. + */ + def getF16Pipeline: LlamaF16Pipeline = + require(f16Weights.isDefined, "F16 pipeline requires F16 weights. Check that model uses F16 quantization.") + f16Pipeline + + /** Read tensor as F16 bytes, converting F32 to F16 if needed. */ + private def readAsF16Bytes(tensor: TensorInfo): Array[Byte] = + if tensor.quantType == QuantType.F16 then + model.gguf.readTensorF16Bytes(tensor) + else if tensor.quantType == QuantType.F32 then + val f32Array = model.gguf.readTensorDequantized(tensor) + val f16Bytes = new Array[Byte](f32Array.length * 2) + val buf = java.nio.ByteBuffer.wrap(f16Bytes).order(java.nio.ByteOrder.LITTLE_ENDIAN) + for (f32Val, idx) <- f32Array.zipWithIndex do + val f16Bits = floatToFloat16Bits(f32Val) + buf.putShort(idx * 2, f16Bits.toShort) + f16Bytes + else + throw new IllegalArgumentException(s"Cannot convert ${tensor.quantType} to F16") + + /** Convert F32 to F16 bits (IEEE 754 half precision). 
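+   * This is a truncating (round-toward-zero) conversion: the mantissa is cut to 10 bits
+   * without rounding, overflow saturates to signed infinity, and values below the F16
+   * normal range flush to signed zero. On Java 20+, java.lang.Float.floatToFloat16 is a
+   * correctly rounded built-in alternative.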
*/
+  private def floatToFloat16Bits(value: Float): Int =
+    val bits = java.lang.Float.floatToRawIntBits(value)
+    val sign = (bits >> 31) & 0x1
+    val exp = (bits >> 23) & 0xFF
+    val frac = bits & 0x7FFFFF
+
+    // Inf/NaN: keep all exponent bits set, preserving a NaN payload bit when frac != 0
+    if exp == 0xFF then
+      return (sign << 15) | 0x7C00 | (if frac != 0 then 1 else 0)
+
+    // Rebias exponent from F32 (bias 127) to F16 (bias 15)
+    val f16Exp = exp - 127 + 15
+    // Overflow: saturate to signed infinity instead of emitting a NaN bit pattern
+    if f16Exp >= 31 then
+      return (sign << 15) | 0x7C00
+    // Underflow (F16 zero/subnormal range): flush to signed zero
+    if f16Exp <= 0 then
+      return sign << 15
+
+    val f16Frac = frac >> 13 // truncate mantissa from 23 to 10 bits
+    (sign << 15) | (f16Exp << 10) | f16Frac
+
+  private def loadF16Weights(): LlamaF16Pipeline.F16ModelWeights =
+    Logger.info(s"Loading F16 weights (${config.numHiddenLayers} layers)...")
+    val startTime = System.currentTimeMillis()
+
+    val tokenEmbed = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.tokenEmbed).get)
+    val outputNorm = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.outputNorm).get)
+    val output = model.getTensor(LlamaModel.TensorNames.output) match
+      case Some(tensor) => readAsF16Bytes(tensor)
+      case None =>
+        Logger.debug("Using tied embeddings (no output.weight)")
+        tokenEmbed
+
+    val layers = (0 until config.numHiddenLayers).map: l =>
+      LlamaF16Pipeline.F16LayerWeights(
+        attnNorm = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.attnNorm(l)).get),
+        wq = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.attnQ(l)).get),
+        wk = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.attnK(l)).get),
+        wv = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.attnV(l)).get),
+        wo = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.attnOutput(l)).get),
+        ffnNorm = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.ffnNorm(l)).get),
+        ffnGate = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.ffnGate(l)).get),
+        ffnUp = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.ffnUp(l)).get),
+        ffnDown = readAsF16Bytes(model.getTensor(LlamaModel.TensorNames.ffnDown(l)).get),
+      )
+
+    val elapsed = System.currentTimeMillis() - startTime
+    // Sum byte counts as Long: a billion-parameter F16 model exceeds Int.MaxValue bytes
+    val totalMB = (tokenEmbed.length.toLong + outputNorm.length + output.length +
+      layers.map(l => l.attnNorm.length.toLong + l.wq.length + l.wk.length + l.wv.length + l.wo.length +
+        l.ffnNorm.length + l.ffnGate.length + l.ffnUp.length + l.ffnDown.length).sum) / 1024 / 1024
+    Logger.info(s"F16 weights loaded: ${elapsed}ms, ${totalMB}MB")
+
+    LlamaF16Pipeline.F16ModelWeights(tokenEmbed, layers, outputNorm, output)
+
+object LlamaInference:
+  def apply(model: LlamaModel)(using runtime: VkCyfraRuntime): LlamaInference =
+    new LlamaInference(model)
diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaConfig.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaConfig.scala
new file mode 100644
index 00000000..ff80c9bb
--- /dev/null
+++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaConfig.scala
@@ -0,0 +1,119 @@
+package io.computenode.cyfra.llama.model
+
+import io.computenode.cyfra.dsl.{*, given}
+import io.computenode.cyfra.dsl.struct.GStruct
+
+/** Llama model configuration.
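+ *
+ * As a sanity check on numParameters below: the TinyLlama_1B preset evaluates to about
+ * 1.10e9 (65.5M embeddings + 22 layers x 44.0M + untied output projection + final norm),
+ * matching the model's advertised 1.1B size.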
+ *
+ * Based on the Llama 2 / Llama 3 architecture with:
+ * - RMSNorm instead of LayerNorm
+ * - SiLU activation in MLP
+ * - Rotary Position Embeddings (RoPE)
+ * - Grouped Query Attention (GQA)
+ * - SwiGLU MLP structure
+ */
+case class LlamaConfig(
+  hiddenSize: Int,             // Model dimension (d_model)
+  intermediateSize: Int,       // MLP hidden dimension (usually ~2.7x hidden)
+  numAttentionHeads: Int,      // Query heads
+  numKeyValueHeads: Int,       // Key/Value heads (for GQA)
+  numHiddenLayers: Int,        // Number of transformer blocks
+  vocabSize: Int,              // Vocabulary size
+  maxPositionEmbeddings: Int,  // Max context length
+  rmsNormEps: Float = 1e-6f,   // RMSNorm epsilon
+  ropeTheta: Float = 10000.0f, // RoPE base frequency
+  bos_token_id: Int = 1,       // Beginning of sequence token
+  eos_token_id: Int = 2,       // End of sequence token
+):
+  def headSize: Int = hiddenSize / numAttentionHeads
+  def kvHeadSize: Int = hiddenSize / numKeyValueHeads
+  def gqaRatio: Int = numAttentionHeads / numKeyValueHeads // GQA ratio (1 = MHA, >1 = GQA)
+
+  /** Estimated total parameter count: embeddings, per-layer weights, final norm, and the output projection. */
+  def numParameters: Long =
+    // Embeddings
+    val embedParams = vocabSize.toLong * hiddenSize
+    // Per-layer params (kept as Long: layers * perLayerParams overflows Int for 7B+ models)
+    val qkvParams = hiddenSize.toLong * (hiddenSize + 2 * (hiddenSize * numKeyValueHeads / numAttentionHeads))
+    val outputParams = hiddenSize.toLong * hiddenSize
+    val mlpParams = 3L * hiddenSize * intermediateSize // gate, up, down projections
+    val normParams = 2L * hiddenSize // 2 RMSNorms per layer
+    val perLayerParams = qkvParams + outputParams + mlpParams + normParams
+    // Total
+    embedParams + numHiddenLayers * perLayerParams + hiddenSize + embedParams // final norm + output proj
+
+object LlamaConfig:
+  /** TinyLlama 1.1B configuration */
+  val TinyLlama_1B: LlamaConfig = LlamaConfig(
+    hiddenSize = 2048,
+    intermediateSize = 5632,
+    numAttentionHeads = 32,
+    numKeyValueHeads = 4,
+    numHiddenLayers = 22,
+    vocabSize = 32000,
+    maxPositionEmbeddings = 2048,
+  )
+
+  /** Llama 2 7B configuration */
+  val Llama2_7B: LlamaConfig = LlamaConfig(
+    hiddenSize = 4096,
+    intermediateSize = 11008,
+    numAttentionHeads = 32,
+    numKeyValueHeads = 32, // MHA (not GQA)
+    numHiddenLayers = 32,
+    vocabSize = 32000,
+    maxPositionEmbeddings = 4096,
+  )
+
+  /** Llama 2 13B configuration */
+  val Llama2_13B: LlamaConfig = LlamaConfig(
+    hiddenSize = 5120,
+    intermediateSize = 13824,
+    numAttentionHeads = 40,
+    numKeyValueHeads = 40,
+    numHiddenLayers = 40,
+    vocabSize = 32000,
+    maxPositionEmbeddings = 4096,
+  )
+
+  /** Llama 3 8B configuration */
+  val Llama3_8B: LlamaConfig = LlamaConfig(
+    hiddenSize = 4096,
+    intermediateSize = 14336,
+    numAttentionHeads = 32,
+    numKeyValueHeads = 8, // GQA with ratio 4
+    numHiddenLayers = 32,
+    vocabSize = 128256,
+    maxPositionEmbeddings = 8192,
+    ropeTheta = 500000.0f,
+  )
+
+/** GPU-side parameters for Llama operations */
+case class LlamaParams(
+  B: Int32,     // Batch size
+  T: Int32,     // Sequence length (current position for generation)
+  C: Int32,     // Hidden size (channels)
+  NH: Int32,    // Number of attention heads
+  NKV: Int32,   // Number of key-value heads
+  HS: Int32,    // Head size
+  eps: Float32, // RMSNorm epsilon
+) extends GStruct[LlamaParams]
+
+/** GPU-side parameters for RoPE */
+case class RoPEParams(
+  headSize: Int32,
+  maxSeqLen: Int32,
+  theta: Float32,
+  position: Int32, // Current position in sequence
+) extends GStruct[RoPEParams]
+
+/** GPU-side parameters for flash attention */
+case class FlashAttnParams(
+  B: Int32,    // Batch size
+  T: Int32,    // Current
sequence length + maxT: Int32, // Maximum sequence length (for KV cache) + NH: Int32, // Number of query heads + NKV: Int32, // Number of KV heads + HS: Int32, // Head size + scale: Float32, // 1/sqrt(head_size) +) extends GStruct[FlashAttnParams] diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaModel.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaModel.scala new file mode 100644 index 00000000..6c34cc7c --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/model/LlamaModel.scala @@ -0,0 +1,89 @@ +package io.computenode.cyfra.llama.model + +import io.computenode.cyfra.llama.gguf.GGUFReader +import io.computenode.cyfra.llama.gguf.GGUFReader.* +import io.computenode.cyfra.llama.util.Logger +import java.nio.file.Path + +/** Llama model loaded from GGUF file. + * + * Contains model configuration and weight tensors. + */ +case class LlamaModel( + config: LlamaConfig, + gguf: GGUFFile, +): + /** Get a weight tensor by name. */ + def getTensor(name: String): Option[TensorInfo] = gguf.getTensor(name) + + /** Read a weight tensor as Float32 array (only for F32 tensors). */ + def readWeightF32(name: String): Array[Float] = + gguf.getTensor(name) match + case Some(tensor) => gguf.readTensorF32(tensor) + case None => throw new IllegalArgumentException(s"Tensor not found: $name") + + /** Read raw tensor bytes (for quantized tensors). */ + def readWeightBytes(name: String): Array[Byte] = + gguf.getTensor(name) match + case Some(tensor) => gguf.readTensorBytes(tensor) + case None => throw new IllegalArgumentException(s"Tensor not found: $name") + + /** Close the underlying file. */ + def close(): Unit = gguf.close() + + /** List all tensor names in the model. */ + def tensorNames: Seq[String] = gguf.tensors.map(_.name) + + /** Get model architecture name. */ + def architecture: String = gguf.getString("general.architecture").getOrElse("unknown") + + /** Get model name. */ + def name: String = gguf.getString("general.name").getOrElse("unknown") + + /** Log model info at INFO level. */ + def logInfo(): Unit = + Logger.info(s"Model: $name, arch=$architecture, ${gguf.tensors.size} tensors") + Logger.info(s"Config: ${config.hiddenSize}d, ${config.numHiddenLayers}L, ${config.numAttentionHeads}H, vocab=${config.vocabSize}") + +object LlamaModel: + /** Load Llama model from GGUF file. + * + * Extracts model configuration from GGUF metadata. + */ + def fromGGUF(path: Path): LlamaModel = + val gguf = GGUFReader.read(path) + + // Extract architecture-specific metadata prefix + val arch = gguf.getString("general.architecture").getOrElse("llama") + + // Extract model configuration from metadata + val config = LlamaConfig( + hiddenSize = gguf.getInt(s"$arch.embedding_length").getOrElse(4096), + intermediateSize = gguf.getInt(s"$arch.feed_forward_length").getOrElse(11008), + numAttentionHeads = gguf.getInt(s"$arch.attention.head_count").getOrElse(32), + numKeyValueHeads = gguf.getInt(s"$arch.attention.head_count_kv").getOrElse(32), + numHiddenLayers = gguf.getInt(s"$arch.block_count").getOrElse(32), + vocabSize = gguf.getInt(s"$arch.vocab_size").getOrElse(32000), + maxPositionEmbeddings = gguf.getInt(s"$arch.context_length").getOrElse(2048), + rmsNormEps = gguf.getFloat(s"$arch.attention.layer_norm_rms_epsilon").getOrElse(1e-6f), + ropeTheta = gguf.getFloat(s"$arch.rope.freq_base").getOrElse(10000.0f), + ) + + LlamaModel(config, gguf) + + /** Common Llama tensor name patterns. 
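+   * These follow the GGUF/llama.cpp naming scheme, e.g. attnQ(0) yields
+   * "blk.0.attn_q.weight" and tokenEmbed yields "token_embd.weight".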
*/ + object TensorNames: + def tokenEmbed: String = "token_embd.weight" + def outputNorm: String = "output_norm.weight" + def output: String = "output.weight" + + def attnNorm(layer: Int): String = s"blk.$layer.attn_norm.weight" + def attnQ(layer: Int): String = s"blk.$layer.attn_q.weight" + def attnK(layer: Int): String = s"blk.$layer.attn_k.weight" + def attnV(layer: Int): String = s"blk.$layer.attn_v.weight" + def attnOutput(layer: Int): String = s"blk.$layer.attn_output.weight" + + def ffnNorm(layer: Int): String = s"blk.$layer.ffn_norm.weight" + def ffnGate(layer: Int): String = s"blk.$layer.ffn_gate.weight" + def ffnUp(layer: Int): String = s"blk.$layer.ffn_up.weight" + def ffnDown(layer: Int): String = s"blk.$layer.ffn_down.weight" diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaF16Pipeline.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaF16Pipeline.scala new file mode 100644 index 00000000..b06456a7 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaF16Pipeline.scala @@ -0,0 +1,660 @@ +package io.computenode.cyfra.llama.pipeline + +import io.computenode.cyfra.core.{CyfraRuntime, GBufferRegion, GExecution} +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.llama.model.LlamaConfig +import io.computenode.cyfra.llama.pipeline.LlamaF16Pipeline.* +import io.computenode.cyfra.llama.pipeline.PipelineUtils.* +import io.computenode.cyfra.llama.programs.* +import io.computenode.cyfra.llama.programs.f16.* +import io.computenode.cyfra.llama.util.Logger +import io.computenode.cyfra.utility.NVTX + +import java.nio.{ByteBuffer, ByteOrder} + +/** F16-Native Llama GPU Pipeline for fast incremental inference. 
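+ *
+ * KV cache footprint, from the kCache/vCache allocations below: 2 caches x L x maxSeqLen
+ * x NKV x headSize halfwords. For example, a 32-layer model with NKV = 8, headSize = 128
+ * and maxSeqLen = 2048 holds about 268 MB of F16 cache.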
+ * + * All compute in half precision for maximum memory efficiency: + * - F16 weights loaded directly from GGUF (no conversion) + * - F16 compute throughout (matmul, attention, FFN) + * - Only final logits in F32 (for softmax stability) + * - 2x memory savings vs F32 pipeline + * + * KV-cached inference: + * - Prefill: Process all prompt tokens at once, fill KV cache + * - Decode: Process 1 token at a time, attend to full cache + * - O(1) complexity per generated token + */ +case class LlamaF16Pipeline( + weights: F16ModelWeights, + config: LlamaConfig, + maxSeqLen: Int = DefaultMaxSeqLen, + B: Int = 1, +)(using runtime: CyfraRuntime) extends LlamaPipeline: + + private val C = config.hiddenSize + private val V = config.vocabSize + private val L = config.numHiddenLayers + private val NH = config.numAttentionHeads + private val NKV = config.numKeyValueHeads + private val headSize = config.headSize + private val FFN = config.intermediateSize + private val kvSize = NKV * headSize + + Logger.info(s"Uploading F16 weights: $L layers, ${V}×${C} vocab, maxSeqLen=$maxSeqLen") + + private val tokenEmbedBuf = allocateF16Buffer(V * C) + copyF16BytesToBuffer(weights.tokenEmbed, tokenEmbedBuf) + tokenEmbedBuf.rewind() + + private val attnNormBuf = allocateF16Buffer(L * C) + private val wqBuf = allocateF16Buffer(L * C * C) + private val wkBuf = allocateF16Buffer(L * C * kvSize) + private val wvBuf = allocateF16Buffer(L * C * kvSize) + private val woBuf = allocateF16Buffer(L * C * C) + private val ffnNormBuf = allocateF16Buffer(L * C) + private val ffnGateBuf = allocateF16Buffer(L * FFN * C) + private val ffnUpBuf = allocateF16Buffer(L * FFN * C) + private val ffnDownBuf = allocateF16Buffer(L * C * FFN) + + for (layer, layerIdx) <- weights.layers.zipWithIndex do + copyF16BytesToBuffer(layer.attnNorm, attnNormBuf, layerIdx * C * 2) + copyF16BytesToBuffer(layer.wq, wqBuf, layerIdx * C * C * 2) + copyF16BytesToBuffer(layer.wk, wkBuf, layerIdx * C * kvSize * 2) + copyF16BytesToBuffer(layer.wv, wvBuf, layerIdx * C * kvSize * 2) + copyF16BytesToBuffer(layer.wo, woBuf, layerIdx * C * C * 2) + copyF16BytesToBuffer(layer.ffnNorm, ffnNormBuf, layerIdx * C * 2) + copyF16BytesToBuffer(layer.ffnGate, ffnGateBuf, layerIdx * FFN * C * 2) + copyF16BytesToBuffer(layer.ffnUp, ffnUpBuf, layerIdx * FFN * C * 2) + copyF16BytesToBuffer(layer.ffnDown, ffnDownBuf, layerIdx * C * FFN * 2) + + attnNormBuf.rewind(); wqBuf.rewind(); wkBuf.rewind(); wvBuf.rewind(); woBuf.rewind() + ffnNormBuf.rewind(); ffnGateBuf.rewind(); ffnUpBuf.rewind(); ffnDownBuf.rewind() + + private val outputNormBuf = allocateF16Buffer(C) + copyF16BytesToBuffer(weights.outputNorm, outputNormBuf) + outputNormBuf.rewind() + + private val outputWeightBuf = allocateF16Buffer(V * C) + copyF16BytesToBuffer(weights.output, outputWeightBuf) + outputWeightBuf.rewind() + + Logger.info("F16 weights uploaded to GPU") + + private val decodeTokenBuf = allocateIntBuffer(B * 1) + private val decodeLogitsBuf = allocateF32Buffer(B * 1 * V) + private val prefillLogitsArr = new Array[Float](V) + private val attnParamsBuf = ByteBuffer.allocateDirect(8).order(ByteOrder.nativeOrder()) + + // GPU sampling buffers + private val sampleParamsBuf = ByteBuffer.allocateDirect(16).order(ByteOrder.nativeOrder()) // std140 aligned + private val sampledTokenBuf = allocateIntBuffer(1) + private val random = new scala.util.Random() + + // CPU sampling for prefill (can't use decode pipeline for variable-length prefill) + private def cpuSample(logits: Array[Float], temperature: Float, topP: 
Float): Int = + if temperature < 0.001f then + var maxIdx = 0 + var maxVal = logits(0) + var i = 1 + while i < logits.length do + if logits(i) > maxVal then + maxVal = logits(i) + maxIdx = i + i += 1 + maxIdx + else + val scaled = logits.map(_ / temperature) + val maxLogit = scaled.max + val expLogits = scaled.map(x => math.exp(x - maxLogit).toFloat) + val sumExp = expLogits.sum + val probs = expLogits.map(_ / sumExp) + val indexed = probs.zipWithIndex.sortBy(-_._1) + var cumSum = 0.0f + var cutoffIdx = 0 + while cutoffIdx < indexed.length && cumSum < topP do + cumSum += indexed(cutoffIdx)._1 + cutoffIdx += 1 + val topTokens = indexed.take(cutoffIdx) + val topSum = topTokens.map(_._1).sum + val threshold = random.nextFloat() * topSum + var acc = 0.0f + var result = topTokens.last._2 + for (prob, idx) <- topTokens do + acc += prob + if acc >= threshold && result == topTokens.last._2 then + result = idx + result + + private val pipelineCache = scala.collection.mutable.Map[(Int, Int, Boolean), GExecution[PipelineParams, PipelineLayout, PipelineLayout]]() + + private def getOrBuildPipeline(T: Int, seqLen: Int, withSampling: Boolean = false): GExecution[PipelineParams, PipelineLayout, PipelineLayout] = + pipelineCache.getOrElseUpdate((T, seqLen, withSampling), buildPipeline(config, B, T, maxSeqLen, withSampling)) + + // Decode pipeline with GPU sampling appended + + private var currentSeqLen: Int = 0 + + def seqLen: Int = currentSeqLen + + private var _lastStats: GenerationStats = null + def lastStats: Option[GenerationStats] = Option(_lastStats) + + def generate( + promptTokens: Array[Int], + maxNewTokens: Int, + temperature: Float = 0.7f, + topP: Float = 0.9f, + onToken: Int => Unit = _ => (), + stopTokens: Set[Int] = Set.empty, + reportStats: Boolean = false, + ): Array[Int] = + require( + promptTokens.length + maxNewTokens <= maxSeqLen, + s"Total sequence ${promptTokens.length + maxNewTokens} exceeds maxSeqLen=$maxSeqLen", + ) + + currentSeqLen = 0 + val generatedTokens = scala.collection.mutable.ArrayBuffer[Int]() + val prefillT = promptTokens.length + + tokenEmbedBuf.rewind() + attnNormBuf.rewind() + wqBuf.rewind() + wkBuf.rewind() + wvBuf.rewind() + woBuf.rewind() + ffnNormBuf.rewind() + ffnGateBuf.rewind() + ffnUpBuf.rewind() + ffnDownBuf.rewind() + outputNormBuf.rewind() + outputWeightBuf.rewind() + + val prefillTokensBuf = allocateIntBuffer(B * prefillT) + prefillTokensBuf.asIntBuffer().put(promptTokens) + prefillTokensBuf.rewind() + val prefillLogitsBuf = allocateF32Buffer(B * prefillT * V) + + val prefillPipeline = getOrBuildPipeline(prefillT, prefillT, withSampling = false) + val decodePipeline = getOrBuildPipeline(1, maxSeqLen, withSampling = true) + val prefillParams = PipelineParams(config, B, prefillT, 0) + + val prefillAttnBuf = ByteBuffer.allocateDirect(8).order(ByteOrder.nativeOrder()) + prefillAttnBuf.putInt(prefillT) + prefillAttnBuf.putInt(0) + prefillAttnBuf.flip() + + val decodeAttnBuf = ByteBuffer.allocateDirect(8).order(ByteOrder.nativeOrder()) + decodeAttnBuf.putInt(prefillT + 1) + decodeAttnBuf.putInt(prefillT) + decodeAttnBuf.flip() + + var prefillStartNs = 0L + var prefillEndNs = 0L + var decodeTimeNs = 0L + var shouldStop = false + + val afterPrefill = GBufferRegion + .allocate[GenerationLayout] + .map: layout => + NVTX.push(s"Prefill[$prefillT]") + prefillStartNs = System.nanoTime() + prefillPipeline.execute(prefillParams, layout.toPipelineLayout) + layout + .map: layout => + layout.prefillLogits.read(prefillLogitsBuf) + prefillEndNs = System.nanoTime() + 
NVTX.pop() + prefillLogitsBuf.rewind() + + // Extract last position logits for sampling + val prefillLogitsArr = new Array[Float](prefillT * V) + copyFromF32Buffer(prefillLogitsBuf, prefillLogitsArr) + val lastPosLogits = prefillLogitsArr.slice((prefillT - 1) * V, prefillT * V) + + val firstToken = cpuSample(lastPosLogits, temperature, topP) + generatedTokens += firstToken + onToken(firstToken) + currentSeqLen = prefillT + + if stopTokens.contains(firstToken) then shouldStop = true + else + decodeTokenBuf.clear() + decodeTokenBuf.asIntBuffer().put(Array(firstToken)) + decodeTokenBuf.rewind() + layout.decodeToken.write(decodeTokenBuf) + + layout + + val afterDecode = (0 until maxNewTokens - 1).foldLeft(afterPrefill): (region, stepIdx) => + val seqLen = prefillT + stepIdx + 1 + val startPos = seqLen - 1 + val decodeParams = PipelineParams(config, B, 1, startPos) + + region.map: layout => + if shouldStop then layout + else + attnParamsBuf.clear() + attnParamsBuf.putInt(seqLen) + attnParamsBuf.putInt(startPos) + attnParamsBuf.flip() + layout.decodeAttnParams.asInstanceOf[io.computenode.cyfra.dsl.binding.GBinding[AttentionParams]].write(attnParamsBuf, 0) + + // Set sampling params + sampleParamsBuf.clear() + sampleParamsBuf.putFloat(temperature) + sampleParamsBuf.putFloat(topP) + sampleParamsBuf.putFloat(random.nextFloat()) + sampleParamsBuf.putFloat(0.0f) + sampleParamsBuf.rewind() + layout.sampleParams.asInstanceOf[io.computenode.cyfra.dsl.binding.GBinding[F16TopPSampleProgram.SampleParams]].write(sampleParamsBuf, 0) + + NVTX.push(s"Decode[$stepIdx]") + val stepStartNs = System.nanoTime() + NVTX.push(s"Execute[$stepIdx]") + decodePipeline.execute(decodeParams, layout.toDecodeLayout) + NVTX.pop() + + // Read sampled token + layout.sampledToken.read(sampledTokenBuf) + decodeTimeNs += (System.nanoTime() - stepStartNs) + NVTX.pop() + sampledTokenBuf.rewind() + val nextToken = sampledTokenBuf.asIntBuffer().get(0) + generatedTokens += nextToken + onToken(nextToken) + currentSeqLen += 1 + + if stopTokens.contains(nextToken) then shouldStop = true + else + decodeTokenBuf.clear() + decodeTokenBuf.asIntBuffer().put(Array(nextToken)) + decodeTokenBuf.rewind() + layout.decodeToken.write(decodeTokenBuf) + + layout + + afterDecode.runUnsafe( + init = GenerationLayout( + tokenEmbed = GBuffer[Float16](tokenEmbedBuf), + attnNorm = GBuffer[Float16](attnNormBuf), + ffnNorm = GBuffer[Float16](ffnNormBuf), + outputNorm = GBuffer[Float16](outputNormBuf), + wq = GBuffer[Vec4[Float16]](wqBuf), + wk = GBuffer[Vec4[Float16]](wkBuf), + wv = GBuffer[Vec4[Float16]](wvBuf), + wo = GBuffer[Vec4[Float16]](woBuf), + ffnGate = GBuffer[Vec4[Float16]](ffnGateBuf), + ffnUp = GBuffer[Vec4[Float16]](ffnUpBuf), + ffnDown = GBuffer[Vec4[Float16]](ffnDownBuf), + outputWeight = GBuffer[Vec4[Float16]](outputWeightBuf), + kCache = GBuffer[Float16](L * maxSeqLen * kvSize), + vCache = GBuffer[Float16](L * maxSeqLen * kvSize), + prefillTokens = GBuffer[Int32](prefillTokensBuf), + prefillHidden = GBuffer[Float16](B * prefillT * C), + prefillResidual = GBuffer[Float16](B * prefillT * C), + prefillAttnNormOut = GBuffer[Float16](B * prefillT * C), + prefillQ = GBuffer[Float16](B * prefillT * C), + prefillK = GBuffer[Float16](B * prefillT * kvSize), + prefillV = GBuffer[Float16](B * prefillT * kvSize), + prefillQRoped = GBuffer[Float16](B * prefillT * C), + prefillKRoped = GBuffer[Float16](B * prefillT * kvSize), + prefillAttnScores = GBuffer[Float32](B * prefillT * NH * maxSeqLen), + prefillAttnOut = GBuffer[Float16](B * prefillT * C), + 
prefillFfnNormOut = GBuffer[Float16](B * prefillT * C), + prefillGate = GBuffer[Float16](B * prefillT * FFN), + prefillUp = GBuffer[Float16](B * prefillT * FFN), + prefillFfnHidden = GBuffer[Float16](B * prefillT * FFN), + prefillFfnOut = GBuffer[Float16](B * prefillT * C), + prefillLogits = GBuffer[Float32](prefillLogitsBuf), + prefillAttnParams = GUniform[AttentionParams](prefillAttnBuf), + decodeToken = GBuffer[Int32](decodeTokenBuf), + decodeHidden = GBuffer[Float16](B * 1 * C), + decodeResidual = GBuffer[Float16](B * 1 * C), + decodeAttnNormOut = GBuffer[Float16](B * 1 * C), + decodeQ = GBuffer[Float16](B * 1 * C), + decodeK = GBuffer[Float16](B * 1 * kvSize), + decodeV = GBuffer[Float16](B * 1 * kvSize), + decodeQRoped = GBuffer[Float16](B * 1 * C), + decodeKRoped = GBuffer[Float16](B * 1 * kvSize), + decodeAttnScores = GBuffer[Float32](B * 1 * NH * maxSeqLen), + decodeAttnOut = GBuffer[Float16](B * 1 * C), + decodeFfnNormOut = GBuffer[Float16](B * 1 * C), + decodeGate = GBuffer[Float16](B * 1 * FFN), + decodeUp = GBuffer[Float16](B * 1 * FFN), + decodeFfnHidden = GBuffer[Float16](B * 1 * FFN), + decodeFfnOut = GBuffer[Float16](B * 1 * C), + decodeLogits = GBuffer[Float32](decodeLogitsBuf), + decodeAttnParams = GUniform[AttentionParams](decodeAttnBuf), + sampleParams = GUniform[F16TopPSampleProgram.SampleParams](sampleParamsBuf), + sampledToken = GBuffer[Int32](sampledTokenBuf), + ), + onDone = _ => (), + ) + + val prefillTimeMs = (prefillEndNs - prefillStartNs) / 1_000_000.0 + val decodeTimeMs = decodeTimeNs / 1_000_000.0 + val totalTimeMs = prefillTimeMs + decodeTimeMs + + _lastStats = GenerationStats( + promptTokens = prefillT, + generatedTokens = generatedTokens.length, + prefillTimeMs = prefillTimeMs, + decodeTimeMs = decodeTimeMs, + totalTimeMs = totalTimeMs, + ) + + if reportStats then Logger.info(_lastStats.toString) + + generatedTokens.toArray + end generate + +end LlamaF16Pipeline + +object LlamaF16Pipeline: + + val DefaultMaxSeqLen = 2048 + + case class PipelineParams( + config: LlamaConfig, + B: Int, + T: Int, + startPos: Int = 0, + ): + def C: Int = config.hiddenSize + def NH: Int = config.numAttentionHeads + def NKV: Int = config.numKeyValueHeads + def headSize: Int = config.headSize + def FFN: Int = config.intermediateSize + def V: Int = config.vocabSize + def L: Int = config.numHiddenLayers + def kvSize: Int = NKV * headSize + + case class F16LayerWeights( + attnNorm: Array[Byte], + wq: Array[Byte], + wk: Array[Byte], + wv: Array[Byte], + wo: Array[Byte], + ffnNorm: Array[Byte], + ffnGate: Array[Byte], + ffnUp: Array[Byte], + ffnDown: Array[Byte], + ) + + case class F16ModelWeights( + tokenEmbed: Array[Byte], + layers: Seq[F16LayerWeights], + outputNorm: Array[Byte], + output: Array[Byte], + ) + + case class PipelineLayout( + tokens: GBuffer[Int32], + tokenEmbed: GBuffer[Float16], + attnNorm: GBuffer[Float16], + ffnNorm: GBuffer[Float16], + outputNorm: GBuffer[Float16], + wq: GBuffer[Vec4[Float16]], + wk: GBuffer[Vec4[Float16]], + wv: GBuffer[Vec4[Float16]], + wo: GBuffer[Vec4[Float16]], + ffnGate: GBuffer[Vec4[Float16]], + ffnUp: GBuffer[Vec4[Float16]], + ffnDown: GBuffer[Vec4[Float16]], + outputWeight: GBuffer[Vec4[Float16]], + kCache: GBuffer[Float16], + vCache: GBuffer[Float16], + hidden: GBuffer[Float16], + residual: GBuffer[Float16], + attnNormOut: GBuffer[Float16], + q: GBuffer[Float16], + k: GBuffer[Float16], + v: GBuffer[Float16], + qRoped: GBuffer[Float16], + kRoped: GBuffer[Float16], + attnScores: GBuffer[Float32], + attnOut: GBuffer[Float16], + ffnNormOut: 
GBuffer[Float16], + gate: GBuffer[Float16], + up: GBuffer[Float16], + ffnHidden: GBuffer[Float16], + ffnOut: GBuffer[Float16], + logits: GBuffer[Float32], + attnParams: GUniform[AttentionParams], + sampleParams: GUniform[F16TopPSampleProgram.SampleParams], + sampledToken: GBuffer[Int32], + ) derives Layout + + case class GenerationLayout( + tokenEmbed: GBuffer[Float16], + attnNorm: GBuffer[Float16], + ffnNorm: GBuffer[Float16], + outputNorm: GBuffer[Float16], + wq: GBuffer[Vec4[Float16]], + wk: GBuffer[Vec4[Float16]], + wv: GBuffer[Vec4[Float16]], + wo: GBuffer[Vec4[Float16]], + ffnGate: GBuffer[Vec4[Float16]], + ffnUp: GBuffer[Vec4[Float16]], + ffnDown: GBuffer[Vec4[Float16]], + outputWeight: GBuffer[Vec4[Float16]], + kCache: GBuffer[Float16], + vCache: GBuffer[Float16], + prefillTokens: GBuffer[Int32], + prefillHidden: GBuffer[Float16], + prefillResidual: GBuffer[Float16], + prefillAttnNormOut: GBuffer[Float16], + prefillQ: GBuffer[Float16], + prefillK: GBuffer[Float16], + prefillV: GBuffer[Float16], + prefillQRoped: GBuffer[Float16], + prefillKRoped: GBuffer[Float16], + prefillAttnScores: GBuffer[Float32], + prefillAttnOut: GBuffer[Float16], + prefillFfnNormOut: GBuffer[Float16], + prefillGate: GBuffer[Float16], + prefillUp: GBuffer[Float16], + prefillFfnHidden: GBuffer[Float16], + prefillFfnOut: GBuffer[Float16], + prefillLogits: GBuffer[Float32], + prefillAttnParams: GUniform[AttentionParams], + decodeToken: GBuffer[Int32], + decodeHidden: GBuffer[Float16], + decodeResidual: GBuffer[Float16], + decodeAttnNormOut: GBuffer[Float16], + decodeQ: GBuffer[Float16], + decodeK: GBuffer[Float16], + decodeV: GBuffer[Float16], + decodeQRoped: GBuffer[Float16], + decodeKRoped: GBuffer[Float16], + decodeAttnScores: GBuffer[Float32], + decodeAttnOut: GBuffer[Float16], + decodeFfnNormOut: GBuffer[Float16], + decodeGate: GBuffer[Float16], + decodeUp: GBuffer[Float16], + decodeFfnHidden: GBuffer[Float16], + decodeFfnOut: GBuffer[Float16], + decodeLogits: GBuffer[Float32], + decodeAttnParams: GUniform[AttentionParams], + // Sampling buffers + sampleParams: GUniform[F16TopPSampleProgram.SampleParams], + sampledToken: GBuffer[Int32], + ) derives Layout: + + def toPipelineLayout: PipelineLayout = PipelineLayout( + tokens = prefillTokens, + tokenEmbed = tokenEmbed, attnNorm = attnNorm, ffnNorm = ffnNorm, outputNorm = outputNorm, + wq = wq, wk = wk, wv = wv, wo = wo, + ffnGate = ffnGate, ffnUp = ffnUp, ffnDown = ffnDown, outputWeight = outputWeight, + kCache = kCache, vCache = vCache, + hidden = prefillHidden, residual = prefillResidual, attnNormOut = prefillAttnNormOut, + q = prefillQ, k = prefillK, v = prefillV, qRoped = prefillQRoped, kRoped = prefillKRoped, + attnScores = prefillAttnScores, attnOut = prefillAttnOut, ffnNormOut = prefillFfnNormOut, + gate = prefillGate, up = prefillUp, ffnHidden = prefillFfnHidden, ffnOut = prefillFfnOut, + logits = prefillLogits, attnParams = prefillAttnParams, sampleParams = sampleParams, sampledToken = sampledToken, + ) + + def toDecodeLayout: PipelineLayout = PipelineLayout( + tokens = decodeToken, + tokenEmbed = tokenEmbed, attnNorm = attnNorm, ffnNorm = ffnNorm, outputNorm = outputNorm, + wq = wq, wk = wk, wv = wv, wo = wo, + ffnGate = ffnGate, ffnUp = ffnUp, ffnDown = ffnDown, outputWeight = outputWeight, + kCache = kCache, vCache = vCache, + hidden = decodeHidden, residual = decodeResidual, attnNormOut = decodeAttnNormOut, + q = decodeQ, k = decodeK, v = decodeV, qRoped = decodeQRoped, kRoped = decodeKRoped, + attnScores = decodeAttnScores, attnOut = 
decodeAttnOut, ffnNormOut = decodeFfnNormOut, + gate = decodeGate, up = decodeUp, ffnHidden = decodeFfnHidden, ffnOut = decodeFfnOut, + logits = decodeLogits, attnParams = decodeAttnParams, sampleParams = sampleParams, sampledToken = sampledToken, + ) + + def buildPipeline( + config: LlamaConfig, + B: Int, + T: Int, + maxSeqLen: Int, + withSampling: Boolean = false, + ): GExecution[PipelineParams, PipelineLayout, PipelineLayout] = + val C = config.hiddenSize + val NH = config.numAttentionHeads + val NKV = config.numKeyValueHeads + val headSize = config.headSize + val FFN = config.intermediateSize + val V = config.vocabSize + val L = config.numHiddenLayers + val kvSize = NKV * headSize + val eps = config.rmsNormEps.toFloat + val theta = config.ropeTheta.toFloat + val startPos = maxSeqLen - T + val copySizeBytes = B * T * C * 2 + + require(C % 4 == 0, s"hiddenSize ($C) must be divisible by 4") + require(kvSize % 4 == 0, s"kvSize ($kvSize) must be divisible by 4") + require(FFN % 4 == 0, s"intermediateSize ($FFN) must be divisible by 4") + + val embSizes = F16EmbeddingProgram.Sizes(B * T, C, V) + val finalNormSizes = F16RMSNormProgram.Sizes(B * T, C, eps, 0, C) + val logitsSizes = F16OutputVec4Program.Sizes(B * T, C, V) + + val afterEmbedding = GExecution[PipelineParams, PipelineLayout]() + .addProgram(F16EmbeddingProgram.forward(embSizes))( + _ => embSizes, + l => F16EmbeddingProgram.ProgramLayout(l.tokens, l.tokenEmbed, l.hidden), + ) + + val afterLayers = (0 until L).foldLeft(afterEmbedding): (pipeline, layer) => + val normOffset = layer * C + val wqOffsetVec4 = layer * C * (C / 4) + val wkOffsetVec4 = layer * C * (kvSize / 4) + val wvOffsetVec4 = layer * C * (kvSize / 4) + val woOffsetVec4 = layer * C * (C / 4) + val ffnGateOffsetVec4 = layer * FFN * (C / 4) + val ffnUpOffsetVec4 = layer * FFN * (C / 4) + val ffnDownOffsetVec4 = layer * C * (FFN / 4) + val kvCacheLayerOffset = layer * maxSeqLen * kvSize + + val attnNormSizes = F16RMSNormProgram.Sizes(B * T, C, eps, normOffset, L * C) + val qkvSizes = F16FusedQKVMatmulProgram.Sizes( + batchSize = B * T, inFeatures = C, qOutFeatures = C, kvOutFeatures = kvSize, + wqOffsetVec4 = wqOffsetVec4, wkOffsetVec4 = wkOffsetVec4, wvOffsetVec4 = wvOffsetVec4, + totalWqVec4 = L * C * (C / 4), totalWkVec4 = L * C * (kvSize / 4), totalWvVec4 = L * C * (kvSize / 4), + ) + val fusedRopeSizes = F16FusedRoPEProgram.Sizes(B, T, NH, NKV, headSize, theta) + val fusedKVWriteSizes = F16FusedKVCacheWriteProgram.Sizes( + B, T, NKV, headSize, maxSeqLen, layer, startPos, kvCacheLayerOffset, kvCacheLayerOffset, L, + ) + val attnScoresSizes = F16AttentionScoresProgram.Sizes(B, T, NH, NKV, headSize, maxSeqLen, kvCacheLayerOffset, L) + val attnSoftmaxSizes = F16AttentionSoftmaxProgram.Sizes(B, T, NH, maxSeqLen) + val attnOutputSizes = F16AttentionOutputProgram.Sizes(B, T, NH, NKV, headSize, maxSeqLen, kvCacheLayerOffset, L) + val woSizes = F16MatmulVecHybridProgram.Sizes(B * T, C, C, woOffsetVec4, L * C * (C / 4)) + val resSizes = F16ResidualAddProgram.Sizes(B * T * C) + val ffnNormSizes = F16RMSNormProgram.Sizes(B * T, C, eps, normOffset, L * C) + val gateSizes = F16MatmulVecHybridProgram.Sizes(B * T, C, FFN, ffnGateOffsetVec4, L * FFN * (C / 4)) + val upSizes = F16MatmulVecHybridProgram.Sizes(B * T, C, FFN, ffnUpOffsetVec4, L * FFN * (C / 4)) + val swiGluSizes = F16SwiGLUProgram.Sizes(B * T * FFN) + val downSizes = F16MatmulVecHybridProgram.Sizes(B * T, FFN, C, ffnDownOffsetVec4, L * C * (FFN / 4)) + + pipeline + .addBufferCopy(l => (l.hidden, l.residual), 
copySizeBytes) + .addProgram(F16RMSNormProgram.forward(attnNormSizes))( + _ => attnNormSizes, + l => F16RMSNormProgram.ProgramLayout(l.hidden, l.attnNorm, l.attnNormOut), + ) + .addProgram(F16FusedQKVMatmulProgram.forward(qkvSizes))( + _ => qkvSizes, + l => F16FusedQKVMatmulProgram.ProgramLayout(l.wq, l.wk, l.wv, l.attnNormOut, l.q, l.k, l.v), + ) + .addProgram(F16FusedRoPEProgram.forward(fusedRopeSizes))( + _ => fusedRopeSizes, + l => F16FusedRoPEProgram.ProgramLayout(l.q, l.k, l.qRoped, l.kRoped, l.attnParams), + ) + .addProgram(F16FusedKVCacheWriteProgram.forward(fusedKVWriteSizes))( + _ => fusedKVWriteSizes, + l => F16FusedKVCacheWriteProgram.ProgramLayout(l.kRoped, l.v, l.kCache, l.vCache, l.attnParams), + ) + .addProgram(F16AttentionScoresProgram.forward(attnScoresSizes))( + _ => attnScoresSizes, + l => F16AttentionScoresProgram.ProgramLayout(l.qRoped, l.kCache, l.attnScores, l.attnParams), + ) + .addProgram(F16AttentionSoftmaxProgram.forward(attnSoftmaxSizes))( + _ => attnSoftmaxSizes, + l => F16AttentionSoftmaxProgram.ProgramLayout(l.attnScores, l.attnParams), + ) + .addProgram(F16AttentionOutputProgram.forward(attnOutputSizes))( + _ => attnOutputSizes, + l => F16AttentionOutputProgram.ProgramLayout(l.attnScores, l.vCache, l.attnOut, l.attnParams), + ) + .addProgram(F16MatmulVecHybridProgram.forward(woSizes))( + _ => woSizes, + l => F16MatmulVecHybridProgram.ProgramLayout(l.wo, l.attnOut, l.hidden), + ) + .addProgram(F16ResidualAddProgram.forward(resSizes))( + _ => resSizes, + l => F16ResidualAddProgram.ProgramLayout(l.residual, l.hidden, l.attnNormOut), + ) + .addBufferCopy(l => (l.attnNormOut, l.residual), copySizeBytes) + .addProgram(F16RMSNormProgram.forward(ffnNormSizes))( + _ => ffnNormSizes, + l => F16RMSNormProgram.ProgramLayout(l.attnNormOut, l.ffnNorm, l.ffnNormOut), + ) + .addProgram(F16MatmulVecHybridProgram.forward(gateSizes))( + _ => gateSizes, + l => F16MatmulVecHybridProgram.ProgramLayout(l.ffnGate, l.ffnNormOut, l.gate), + ) + .addProgram(F16MatmulVecHybridProgram.forward(upSizes))( + _ => upSizes, + l => F16MatmulVecHybridProgram.ProgramLayout(l.ffnUp, l.ffnNormOut, l.up), + ) + .addProgram(F16SwiGLUProgram.forward(swiGluSizes))( + _ => swiGluSizes, + l => F16SwiGLUProgram.ProgramLayout(l.gate, l.up, l.ffnHidden), + ) + .addProgram(F16MatmulVecHybridProgram.forward(downSizes))( + _ => downSizes, + l => F16MatmulVecHybridProgram.ProgramLayout(l.ffnDown, l.ffnHidden, l.ffnOut), + ) + .addProgram(F16ResidualAddProgram.forward(resSizes))( + _ => resSizes, + l => F16ResidualAddProgram.ProgramLayout(l.residual, l.ffnOut, l.hidden), + ) + + val afterOutput = afterLayers + .addProgram(F16RMSNormProgram.forward(finalNormSizes))( + _ => finalNormSizes, + l => F16RMSNormProgram.ProgramLayout(l.hidden, l.outputNorm, l.attnNormOut), + ) + .addProgram(F16OutputVec4Program.forward(logitsSizes))( + _ => logitsSizes, + l => F16OutputVec4Program.ProgramLayout(l.attnNormOut, l.outputWeight, l.logits), + ) + + if withSampling then + val sampleSizes = F16TopPSampleProgram.Sizes(V) + afterOutput.addProgram(F16TopPSampleProgram.forward(sampleSizes))( + _ => sampleSizes, + l => F16TopPSampleProgram.ProgramLayout(l.logits, l.sampleParams, l.sampledToken), + ) + else + afterOutput \ No newline at end of file diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaPipeline.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaPipeline.scala new file mode 100644 index 00000000..f2186320 --- /dev/null +++ 
b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/pipeline/LlamaPipeline.scala @@ -0,0 +1,91 @@ +package io.computenode.cyfra.llama.pipeline + +import io.computenode.cyfra.llama.model.LlamaConfig + +import java.nio.{ByteBuffer, ByteOrder} + +/** Common interface for Llama GPU pipelines. + * + * Defines the standard API for KV-cached inference pipelines: + * - generate: Efficient generation with prefill + decode in single GPU allocation + * + * Implementations: + * - LlamaF16Pipeline: F16 precision with Vec4 optimizations + */ +trait LlamaPipeline: + + /** Model configuration. */ + def config: LlamaConfig + + /** Current sequence length (position in KV cache). */ + def seqLen: Int + + /** Last generation statistics. */ + def lastStats: Option[GenerationStats] + + /** Generate tokens with KV cache (GPU sampling). + * + * Optimized generation that keeps KV cache and sampling on GPU: + * - Prefill: Process all prompt tokens at once + * - Decode: Generate tokens one at a time, attending to full cache + * - Sample: GPU-accelerated top-p sampling with temperature + * + * @param promptTokens Input prompt tokens + * @param maxNewTokens Maximum tokens to generate + * @param temperature Sampling temperature (0 = greedy) + * @param topP Top-p (nucleus) sampling threshold + * @param onToken Callback for each generated token + * @param stopTokens Set of tokens that stop generation + * @param reportStats If true, logs performance stats after generation + * @return Array of generated tokens (not including prompt) + */ + def generate( + promptTokens: Array[Int], + maxNewTokens: Int, + temperature: Float = 0.7f, + topP: Float = 0.9f, + onToken: Int => Unit = _ => (), + stopTokens: Set[Int] = Set.empty, + reportStats: Boolean = false, + ): Array[Int] + + +/** Performance metrics from generation. */ +case class GenerationStats( + promptTokens: Int, + generatedTokens: Int, + prefillTimeMs: Double, + decodeTimeMs: Double, + totalTimeMs: Double, +): + def prefillTokPerSec: Double = if prefillTimeMs > 0 then promptTokens * 1000.0 / prefillTimeMs else 0 + def decodeTokPerSec: Double = if decodeTimeMs > 0 then generatedTokens * 1000.0 / decodeTimeMs else 0 + def totalTokPerSec: Double = if totalTimeMs > 0 then (promptTokens + generatedTokens) * 1000.0 / totalTimeMs else 0 + + override def toString: String = + f"Gen: ${promptTokens}p+${generatedTokens}g, prefill=${prefillTokPerSec}%.0f tok/s, generate=${decodeTokPerSec}%.1f tok/s" + +/** Buffer utilities for pipeline implementations. 
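+ *
+ * A small usage sketch of the copy helpers (values are illustrative):
+ * {{{
+ * import PipelineUtils.*
+ * val buf = allocateF32Buffer(4)
+ * copyToF32Buffer(Array(1.0f, 2.0f, 3.0f, 4.0f), buf)
+ * val out = new Array[Float](4)
+ * copyFromF32Buffer(buf, out) // out now mirrors the input array
+ * }}}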
*/ +object PipelineUtils: + + def allocateF32Buffer(floatCount: Int): ByteBuffer = + ByteBuffer.allocateDirect(floatCount * 4).order(ByteOrder.nativeOrder()) + + def allocateF16Buffer(f16Count: Int): ByteBuffer = + ByteBuffer.allocateDirect(f16Count * 2).order(ByteOrder.nativeOrder()) + + def allocateIntBuffer(intCount: Int): ByteBuffer = + ByteBuffer.allocateDirect(intCount * 4).order(ByteOrder.nativeOrder()) + + def copyToF32Buffer(arr: Array[Float], buf: ByteBuffer): Unit = + buf.clear(); buf.asFloatBuffer().put(arr); buf.rewind() + + def copyFromF32Buffer(buf: ByteBuffer, arr: Array[Float]): Unit = + buf.rewind(); buf.asFloatBuffer().get(arr) + + def copyIntToBuffer(arr: Array[Int], buf: ByteBuffer): Unit = + buf.clear(); buf.asIntBuffer().put(arr); buf.rewind() + + def copyF16BytesToBuffer(bytes: Array[Byte], buf: ByteBuffer, offset: Int = 0): Unit = + buf.position(offset) + buf.put(bytes) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16AttentionOutputProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16AttentionOutputProgram.scala new file mode 100644 index 00000000..f7e6b466 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16AttentionOutputProgram.scala @@ -0,0 +1,161 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.binding.GBuffer +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.llama.programs.AttentionParams + +/** Attention output: attn_weights · V + * + * OPTIMIZED with 3D dispatch to eliminate integer divisions: + * - X = dimQuadIdx (16 workgroups for headSize=64) + * - Y = headIdx (32 Q heads) + * - Z = batch * T (1 for decode) + * + * - 32 threads (1 warp) - no cross-warp reduction needed + * - Subgroup reduction only (fast path) + * + * V cache is TRANSPOSED: [layer][head][dim][pos] + * This ensures consecutive threads (different positions) read consecutive memory, + * achieving 100% memory coalescing vs 0.8% with the old layout. 
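+ *
+ * Addressing sketch for the transposed layout: element (head h, dim d, pos p) of a layer
+ * lives at vCacheLayerOffset + h * headStride + d * dimStride + p, so the 32 threads of a
+ * warp scanning consecutive p read one contiguous 64-byte span of halfwords.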
+ */ +object F16AttentionOutputProgram: + val WARP_SIZE = 32 + val NUM_DIMS = 4 // Process 4 output dimensions per workgroup (Vec4) + val BLOCK_SIZE = WARP_SIZE // 32 threads - single warp + + case class Sizes( + B: Int, + T: Int, + NH: Int, + NKV: Int, + headSize: Int, + maxSeqLen: Int, + vCacheLayerOffset: Int, + L: Int, + ): + def gqaRatio: Int = NH / NKV + def kvSizePerPos: Int = NKV * headSize + def fullCacheSize: Int = L * maxSeqLen * kvSizePerPos + // Size of one dim's position array in transposed V cache + def dimStride: Int = maxSeqLen + // Size of one head's data in transposed V cache + def headStride: Int = headSize * maxSeqLen + // Number of iterations over K positions per thread + def numKIterations: Int = (maxSeqLen + BLOCK_SIZE - 1) / BLOCK_SIZE + // 3D dispatch dimensions + def dispatchX: Int = (headSize + NUM_DIMS - 1) / NUM_DIMS // dimQuad workgroups (16) + def dispatchY: Int = NH // Q heads (32) + def dispatchZ: Int = B * T // batch × positions (1 for decode) + + case class ProgramLayout( + attnWeights: GBuffer[Float32], + vCache: GBuffer[Float16], + output: GBuffer[Float16], + params: GUniform[AttentionParams], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val B = sizes.B + val T = sizes.T + val NH = sizes.NH + val NKV = sizes.NKV + val headSize = sizes.headSize + val gqaRatio = sizes.gqaRatio + val vCacheLayerOffset = sizes.vCacheLayerOffset + val maxSeqLen = sizes.maxSeqLen + val dimStride = sizes.dimStride // = maxSeqLen + val headStride = sizes.headStride // = headSize * maxSeqLen + val numKIterations = sizes.numKIterations + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + attnWeights = GBuffer[Float32](s.B * s.T * s.NH * s.maxSeqLen), + vCache = GBuffer[Float16](s.fullCacheSize), + output = GBuffer[Float16](s.B * s.T * s.NH * s.headSize), + params = GUniform[AttentionParams](), + ), + // 3D dispatch: [dimQuads, heads, batch*T] + dispatch = (_, s) => StaticDispatch((s.dispatchX, s.dispatchY, s.dispatchZ)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + + // 3D workgroup IDs - NO DIVISIONS FOR DISPATCH! + val dimQuadIdx: Int32 = GIO.workgroupId.x + val headIdx: Int32 = GIO.workgroupId.y + val batchPosIdx: Int32 = GIO.workgroupId.z + + val runtimeParams = layout.params.read + val seqLen: Int32 = runtimeParams.seqLen + + for + // KV head from Q head - only ONE division (gqaRatio is compile-time) + kvHeadIdx <- GIO.pure(headIdx / gqaRatio) + + // Output dimensions for this workgroup (4 dims) + outDim0 <- GIO.pure(dimQuadIdx * NUM_DIMS) + + // Flat head index for weights/output addressing + flatHeadIdx <- GIO.pure(batchPosIdx * NH + headIdx) + + // Pre-compute base addresses ONCE + weightsBase <- GIO.pure(flatHeadIdx * maxSeqLen) + outBase <- GIO.pure(flatHeadIdx * headSize + outDim0) + + // V cache base for this KV head (TRANSPOSED layout: [layer][head][dim][pos]) + // Base address: layerOffset + kvHead * headStride + // For dim d: + d * dimStride + // For pos p: + p (consecutive!) 
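+        // Example: head 1, dim 2, pos 7 -> layerOffset + 1 * headStride + 2 * dimStride + 7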
+ vCacheHeadBase <- GIO.pure((vCacheLayerOffset: Int32) + kvHeadIdx * headStride) + + // Pre-compute base for each of the 4 dims (each dim's positions are consecutive) + vBase0 <- GIO.pure(vCacheHeadBase + outDim0 * dimStride) + vBase1 <- GIO.pure(vCacheHeadBase + (outDim0 + 1) * dimStride) + vBase2 <- GIO.pure(vCacheHeadBase + (outDim0 + 2) * dimStride) + vBase3 <- GIO.pure(vCacheHeadBase + (outDim0 + 3) * dimStride) + + // Hot loop - 32 threads cooperate + // Each thread reads positions tid, tid+32, tid+64, ... + // Consecutive threads read consecutive positions → COALESCED! + localSums <- GIO.pure { + GSeq.gen[Int32](tid, _ + BLOCK_SIZE).limit(numKIterations).unroll.fold(vec4(0.0f, 0.0f, 0.0f, 0.0f), (acc: Vec4[Float32], kPos: Int32) => + // Read weight (softmax output is 0 for invalid positions) + val weight: Float32 = when(kPos < seqLen)(GIO.read[Float32](layout.attnWeights, weightsBase + kPos)).otherwise(0.0f) + + // Read 4 V values from TRANSPOSED cache (4 scalar reads, each coalesced across threads) + // Thread 0 reads pos 0, Thread 1 reads pos 1, ... → consecutive memory! + val v0: Float32 = GIO.read[Float16](layout.vCache, vBase0 + kPos).asFloat32 + val v1: Float32 = GIO.read[Float16](layout.vCache, vBase1 + kPos).asFloat32 + val v2: Float32 = GIO.read[Float16](layout.vCache, vBase2 + kPos).asFloat32 + val v3: Float32 = GIO.read[Float16](layout.vCache, vBase3 + kPos).asFloat32 + + vec4( + acc.x + weight * v0, + acc.y + weight * v1, + acc.z + weight * v2, + acc.w + weight * v3, + ) + ) + } + + // Subgroup reduction - single warp means subgroupAdd gives final result + total <- GIO.pure(vec4( + GIO.subgroupAdd(localSums.x), + GIO.subgroupAdd(localSums.y), + GIO.subgroupAdd(localSums.z), + GIO.subgroupAdd(localSums.w), + )) + + // Thread 0 writes all 4 output values + _ <- GIO.when(tid === 0): + for + _ <- GIO.write[Float16](layout.output, outBase, total.x.asFloat16) + _ <- GIO.write[Float16](layout.output, outBase + 1, total.y.asFloat16) + _ <- GIO.write[Float16](layout.output, outBase + 2, total.z.asFloat16) + _ <- GIO.write[Float16](layout.output, outBase + 3, total.w.asFloat16) + yield GStruct.Empty() + yield GStruct.Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16AttentionScoresProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16AttentionScoresProgram.scala new file mode 100644 index 00000000..fd5ca932 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16AttentionScoresProgram.scala @@ -0,0 +1,108 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.llama.programs.AttentionParams + +/** Attention scores computation: Q · K^T + * + * Computes scaled dot-product attention scores for KV-cached inference. + * One workgroup per (batch, queryPos, head). + * Each thread handles multiple K positions. + * Output is F32 scores buffer for subsequent softmax. + * + * Matches llama.cpp's mul_mat_vec pattern. 
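+ *
+ * Per-pair formula, as computed below: score(q, kPos) = (q . K[kPos]) / sqrt(headSize),
+ * with kPos > queryPos or kPos >= seqLen masked to -10000.0 so softmax gives them ~0 weight.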
+ */ +object F16AttentionScoresProgram: + val BLOCK_SIZE = 32 + + case class Sizes( + B: Int, + T: Int, + NH: Int, + NKV: Int, + headSize: Int, + maxSeqLen: Int, + kCacheLayerOffset: Int, + L: Int, + ): + def gqaRatio: Int = NH / NKV + def kvSizePerPos: Int = NKV * headSize + def fullCacheSize: Int = L * maxSeqLen * kvSizePerPos + def numIterations: Int = (maxSeqLen + BLOCK_SIZE - 1) / BLOCK_SIZE + + case class ProgramLayout( + q: GBuffer[Float16], + kCache: GBuffer[Float16], + scores: GBuffer[Float32], + params: GUniform[AttentionParams], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + // All compile-time constants + val B = sizes.B + val T = sizes.T + val NH = sizes.NH + val NKV = sizes.NKV + val headSize = sizes.headSize + val gqaRatio = sizes.gqaRatio + val kCacheLayerOffset = sizes.kCacheLayerOffset + val kvSizePerPos = sizes.kvSizePerPos + val maxSeqLen = sizes.maxSeqLen + val numIterations = sizes.numIterations + val scale = 1.0f / math.sqrt(headSize).toFloat + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + q = GBuffer[Float16](s.B * s.T * s.NH * s.headSize), + kCache = GBuffer[Float16](s.fullCacheSize), + scores = GBuffer[Float32](s.B * s.T * s.NH * s.maxSeqLen), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch((s.B * s.T * s.NH, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + + val runtimeParams = layout.params.read + val seqLen: Int32 = runtimeParams.seqLen + val startPos: Int32 = runtimeParams.startPos + + for + // Derive batch/pos/head from workgroup ID - compute once + batchIdx <- GIO.pure(workgroupId / (T * NH)) + remainder <- GIO.pure(workgroupId.mod(T * NH)) + queryPosLocal <- GIO.pure(remainder / NH) + headIdx <- GIO.pure(remainder.mod(NH)) + kvHeadIdx <- GIO.pure(headIdx / gqaRatio) + + queryPosGlobal <- GIO.pure(startPos + queryPosLocal) + + // Base addresses - compute once + qBase <- GIO.pure(batchIdx * (T * NH * headSize) + queryPosLocal * (NH * headSize) + headIdx * headSize) + scoresBase <- GIO.pure(workgroupId * maxSeqLen) + kCacheBase0 <- GIO.pure((kCacheLayerOffset: Int32) + kvHeadIdx * headSize) // Partial - add kPos * kvSizePerPos later + + // Each thread handles multiple K positions + _ <- GIO.foldRepeat[GStruct.Empty](numIterations, GStruct.Empty()): (iter, _) => + val kPos: Int32 = tid + iter * BLOCK_SIZE + + // K cache base for this position + val kCacheBase: Int32 = kCacheBase0 + kPos * kvSizePerPos + + // Compute dot product Q · K[kPos] - unrolled inner loop over headSize + val dot: Float32 = GSeq.gen[Int32](0, _ + 1).limit(headSize).unroll.fold(0.0f, (acc: Float32, d: Int32) => + val qVal: Float32 = GIO.read[Float16](layout.q, qBase + d).asFloat32 + val kVal: Float32 = GIO.read[Float16](layout.kCache, kCacheBase + d).asFloat32 + acc + qVal * kVal + ) + + // Scaled score or -inf for invalid (causal masking + bounds) + val isValid = kPos <= queryPosGlobal && kPos < seqLen + val score: Float32 = when(isValid)(dot * scale).otherwise(-10000.0f) + GIO.write[Float32](layout.scores, scoresBase + kPos, score) + yield GStruct.Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16AttentionSoftmaxProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16AttentionSoftmaxProgram.scala new file mode 100644 index 00000000..45a72c58 --- /dev/null +++ 
b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16AttentionSoftmaxProgram.scala @@ -0,0 +1,100 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.llama.programs.AttentionParams + +/** Attention softmax: softmax(scores) in-place + * + * Applies softmax to attention scores with causal masking. + * One workgroup per row (batch, queryPos, head). + * Uses subgroup operations for reduction. + * + * IMPORTANT: Writes 0.0 to invalid positions so output kernel can skip bounds checking. + */ +object F16AttentionSoftmaxProgram: + val BLOCK_SIZE = 32 + + case class Sizes( + B: Int, + T: Int, + NH: Int, + maxSeqLen: Int, + ): + def numRows: Int = B * T * NH + def numIterations: Int = (maxSeqLen + BLOCK_SIZE - 1) / BLOCK_SIZE + + case class ProgramLayout( + scores: GBuffer[Float32], // in-place: [B, T, NH, maxSeqLen] + params: GUniform[AttentionParams], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val B = sizes.B + val T = sizes.T + val NH = sizes.NH + val maxSeqLen = sizes.maxSeqLen + val numIterations = sizes.numIterations + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + scores = GBuffer[Float32](s.B * s.T * s.NH * s.maxSeqLen), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch((s.numRows, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val rowIdx: Int32 = GIO.workgroupId.x + + val runtimeParams = layout.params.read + val seqLen: Int32 = runtimeParams.seqLen + val startPos: Int32 = runtimeParams.startPos + + for + // Derive query position from row index + remainder <- GIO.pure(rowIdx.mod(T * NH)) + queryPosLocal <- GIO.pure(remainder / NH) + queryPosGlobal <- GIO.pure(startPos + queryPosLocal) + rowBase <- GIO.pure(rowIdx * maxSeqLen) + + // Phase 1: Find local max value (only over valid positions) + localMax <- GIO.pure { + GSeq.gen[Int32](tid, _ + BLOCK_SIZE).limit(numIterations).fold(-10000.0f, (maxVal: Float32, col: Int32) => + val isValid = col <= queryPosGlobal && col < seqLen + val score: Float32 = GIO.read[Float32](layout.scores, rowBase + col) + when(isValid)(max(maxVal, score)).otherwise(maxVal) + ) + } + + // Reduce max within warp using subgroupMax + globalMax <- GIO.pure(GIO.subgroupMax(localMax)) + + // Phase 2: Compute exp(x - max) and local sum (only over valid positions) + localSum <- GIO.pure { + GSeq.gen[Int32](tid, _ + BLOCK_SIZE).limit(numIterations).fold(0.0f, (sumVal: Float32, col: Int32) => + val isValid = col <= queryPosGlobal && col < seqLen + val score: Float32 = GIO.read[Float32](layout.scores, rowBase + col) + val expScore: Float32 = exp(score - globalMax) + when(isValid)(sumVal + expScore).otherwise(sumVal) + ) + } + + // Reduce sum within warp using subgroupAdd + globalSum <- GIO.pure(GIO.subgroupAdd(localSum) + 0.0000001f) + rcpSum <- GIO.pure(1.0f / globalSum) + + // Phase 3: Normalize in-place + // Write normalized value for valid positions, 0.0 for invalid + _ <- GIO.foldRepeat[GStruct.Empty](numIterations, GStruct.Empty()): (iter, _) => + val col: Int32 = tid + iter * BLOCK_SIZE + val isValid = col <= queryPosGlobal && col < seqLen + val score: Float32 = GIO.read[Float32](layout.scores, rowBase + col) + val expScore: Float32 = 
exp(score - globalMax) + // Valid positions get normalized value, invalid get 0.0 + val result: Float32 = when(isValid)(expScore * rcpSum).otherwise(0.0f) + GIO.write[Float32](layout.scores, rowBase + col, result) + yield GStruct.Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16EmbeddingProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16EmbeddingProgram.scala new file mode 100644 index 00000000..2427f19f --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16EmbeddingProgram.scala @@ -0,0 +1,43 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO + +/** F16 token embedding lookup. + * + * Maps token IDs to embedding vectors by index lookup. + */ +object F16EmbeddingProgram: + case class Sizes(seqLen: Int, hiddenSize: Int, vocabSize: Int): + def totalOutputs: Int = seqLen * hiddenSize + + case class ProgramLayout( + tokens: GBuffer[Int32], + embeddings: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + tokens = GBuffer[Int32](s.seqLen), + embeddings = GBuffer[Float16](s.vocabSize * s.hiddenSize), + output = GBuffer[Float16](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch(((s.totalOutputs + 255) / 256, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val idx = GIO.invocationId + val hiddenSizeVal: Int32 = sizes.hiddenSize + val totalVal: Int32 = sizes.seqLen * sizes.hiddenSize + + GIO.when(idx < totalVal): + val tokenPos = idx / hiddenSizeVal + val dim = idx.mod(hiddenSizeVal) + val tokenId = GIO.read[Int32](layout.tokens, tokenPos) + val embIdx = tokenId * hiddenSizeVal + dim + val value = GIO.read[Float16](layout.embeddings, embIdx) + GIO.write[Float16](layout.output, idx, value) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedKVCacheWriteProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedKVCacheWriteProgram.scala new file mode 100644 index 00000000..56b8198b --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedKVCacheWriteProgram.scala @@ -0,0 +1,113 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.llama.programs.AttentionParams + +/** Fused F16 KV Cache Write for both K and V in single dispatch. + * + * Reduces dispatch count by writing both K and V vectors to cache simultaneously. + * Each invocation copies one element from either K or V input to the cache. 
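+ *
+ * Illustrative index split (hypothetical sizes, not part of the kernel):
+ * {{{
+ * // With B=1, T=1, NKV=8, headSize=64: totalKElements = 512, totalElements = 1024
+ * // idx = 700  =>  isK = false, localIdx = 700 - 512 = 188
+ * //            =>  h = 188 / 64 = 2, d = 188 mod 64 = 60
+ * }}}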
+ * + * K cache layout: [layer][pos][head][dim] (standard for Q·K dot products) + * V cache layout: [layer][head][dim][pos] (TRANSPOSED for coalesced AttentionOutput reads) + */ +object F16FusedKVCacheWriteProgram: + + case class Sizes( + B: Int, + T: Int, + NKV: Int, + headSize: Int, + maxSeqLen: Int, + layer: Int, + posOffset: Int, + kCacheLayerOffset: Int, + vCacheLayerOffset: Int, + L: Int, + ): + def totalKElements: Int = B * T * NKV * headSize + def totalVElements: Int = B * T * NKV * headSize + def totalElements: Int = totalKElements + totalVElements + def kvSizePerPos: Int = NKV * headSize + def fullCacheSize: Int = L * maxSeqLen * kvSizePerPos + + case class ProgramLayout( + k: GBuffer[Float16], + v: GBuffer[Float16], + kCache: GBuffer[Float16], + vCache: GBuffer[Float16], + params: GUniform[AttentionParams], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val B = sizes.B + val T = sizes.T + val NKV = sizes.NKV + val headSize = sizes.headSize + val maxSeqLen = sizes.maxSeqLen + val totalKElements = sizes.totalKElements + val totalElements = sizes.totalElements + val kCacheLayerOffset = sizes.kCacheLayerOffset + val vCacheLayerOffset = sizes.vCacheLayerOffset + val kvSizePerPos = sizes.kvSizePerPos + val fullCacheSize = sizes.fullCacheSize + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + k = GBuffer[Float16](s.totalKElements), + v = GBuffer[Float16](s.totalVElements), + kCache = GBuffer[Float16](s.fullCacheSize), + vCache = GBuffer[Float16](s.fullCacheSize), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch(((s.totalElements + 255) / 256, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val idx = GIO.invocationId + val posOffsetVal: Int32 = layout.params.read.startPos + + val Tval: Int32 = T + val NKVval: Int32 = NKV + val headSizeVal: Int32 = headSize + val maxSeqLenVal: Int32 = maxSeqLen + val totalKElementsVal: Int32 = totalKElements + val totalElementsVal: Int32 = totalElements + val kCacheLayerOffsetVal: Int32 = kCacheLayerOffset + val vCacheLayerOffsetVal: Int32 = vCacheLayerOffset + val kvSizePerPosVal: Int32 = kvSizePerPos + + GIO.when(idx < totalElementsVal): + // Determine if this is K or V + val isK = idx < totalKElementsVal + val localIdx: Int32 = when(isK)(idx).otherwise(idx - totalKElementsVal) + + // Decompose index + val elementsPerBatch = Tval * NKVval * headSizeVal + val b = localIdx / elementsPerBatch + val remaining1 = localIdx.mod(elementsPerBatch) + val elementsPerPos = NKVval * headSizeVal + val t = remaining1 / elementsPerPos + val remaining2 = remaining1.mod(elementsPerPos) + val h = remaining2 / headSizeVal + val d = remaining2.mod(headSizeVal) + + val cachePos = posOffsetVal + t + // K cache: [layer][pos][head][dim] - standard layout + val kCacheOffset: Int32 = cachePos * kvSizePerPosVal + h * headSizeVal + d + // V cache: [layer][head][dim][pos] - TRANSPOSED for coalesced reads + val vCacheOffset: Int32 = h * headSizeVal * maxSeqLenVal + d * maxSeqLenVal + cachePos + + for + _ <- GIO.when(isK): + val kVal = GIO.read[Float16](layout.k, localIdx) + val cacheIdx = kCacheLayerOffsetVal + kCacheOffset + GIO.write[Float16](layout.kCache, cacheIdx, kVal) + _ <- GIO.when(!isK): + val vVal = GIO.read[Float16](layout.v, localIdx) + val cacheIdx = vCacheLayerOffsetVal + vCacheOffset + GIO.write[Float16](layout.vCache, cacheIdx, vVal) + yield GStruct.Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedQKVMatmulProgram.scala 
b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedQKVMatmulProgram.scala new file mode 100644 index 00000000..5a9ce4e1 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedQKVMatmulProgram.scala @@ -0,0 +1,149 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO + +/** Fused F16 Q/K/V projection in single dispatch. + * + * Computes all three attention projections at once: + * - Q = input @ Wq (outFeatures = C) + * - K = input @ Wk (outFeatures = kvSize) + * - V = input @ Wv (outFeatures = kvSize) + * + * Reduces 3 dispatches → 1, sharing the same input read. + * Each warp computes one output element from Q, K, or V. + */ +object F16FusedQKVMatmulProgram: + val WARP_SIZE = 32 + val WARPS_PER_WORKGROUP = 8 + val BLOCK_SIZE = WARP_SIZE * WARPS_PER_WORKGROUP + + case class Sizes( + batchSize: Int, + inFeatures: Int, // C (hidden size) + qOutFeatures: Int, // C (for Q) + kvOutFeatures: Int, // kvSize = NKV * headSize (for K and V) + wqOffsetVec4: Int, + wkOffsetVec4: Int, + wvOffsetVec4: Int, + totalWqVec4: Int, + totalWkVec4: Int, + totalWvVec4: Int, + ): + require(inFeatures % 4 == 0, s"inFeatures ($inFeatures) must be divisible by 4") + def inFeaturesDiv4: Int = inFeatures / 4 + def totalQOutputs: Int = batchSize * qOutFeatures + def totalKOutputs: Int = batchSize * kvOutFeatures + def totalVOutputs: Int = batchSize * kvOutFeatures + def totalOutputs: Int = totalQOutputs + totalKOutputs + totalVOutputs + def numWorkgroups: Int = (totalOutputs + WARPS_PER_WORKGROUP - 1) / WARPS_PER_WORKGROUP + def numVecIterations: Int = (inFeaturesDiv4 + WARP_SIZE - 1) / WARP_SIZE + + case class ProgramLayout( + wq: GBuffer[Vec4[Float16]], + wk: GBuffer[Vec4[Float16]], + wv: GBuffer[Vec4[Float16]], + input: GBuffer[Float16], + q: GBuffer[Float16], + k: GBuffer[Float16], + v: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val inFeatures = sizes.inFeatures + val inFeaturesDiv4 = sizes.inFeaturesDiv4 + val qOutFeatures = sizes.qOutFeatures + val kvOutFeatures = sizes.kvOutFeatures + val wqOffsetVec4 = sizes.wqOffsetVec4 + val wkOffsetVec4 = sizes.wkOffsetVec4 + val wvOffsetVec4 = sizes.wvOffsetVec4 + val numVecIterations = sizes.numVecIterations + val totalQOutputs = sizes.totalQOutputs + val totalKOutputs = sizes.totalKOutputs + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + wq = GBuffer[Vec4[Float16]](s.totalWqVec4), + wk = GBuffer[Vec4[Float16]](s.totalWkVec4), + wv = GBuffer[Vec4[Float16]](s.totalWvVec4), + input = GBuffer[Float16](s.batchSize * s.inFeatures), + q = GBuffer[Float16](s.totalQOutputs), + k = GBuffer[Float16](s.totalKOutputs), + v = GBuffer[Float16](s.totalVOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId = tid.mod(WARP_SIZE) + val warpId = tid / WARP_SIZE + + val inFeaturesVal: Int32 = inFeatures + val inFeaturesDiv4Val: Int32 = inFeaturesDiv4 + val qOutFeaturesVal: Int32 = qOutFeatures + val kvOutFeaturesVal: Int32 = kvOutFeatures + val wqOffsetVec4Val: Int32 = wqOffsetVec4 + val wkOffsetVec4Val: Int32 = wkOffsetVec4 + val 
wvOffsetVec4Val: Int32 = wvOffsetVec4 + val totalQOutputsVal: Int32 = totalQOutputs + val totalKOutputsVal: Int32 = totalKOutputs + val totalOutputsVal: Int32 = sizes.totalOutputs + + val globalOutputIdx = workgroupId * WARPS_PER_WORKGROUP + warpId + + // Determine which output (Q, K, or V) this warp handles + // Q: indices [0, totalQOutputs) + // K: indices [totalQOutputs, totalQOutputs + totalKOutputs) + // V: indices [totalQOutputs + totalKOutputs, total) + + val isQ = globalOutputIdx < totalQOutputsVal + val isK = !isQ && (globalOutputIdx < totalQOutputsVal + totalKOutputsVal) + val isV = !isQ && !isK + + // Helper function to compute matmul for a given buffer and offset + def computeMatmul( + weightBuffer: GBuffer[Vec4[Float16]], + weightOffset: Int32, + outFeatures: Int32, + localIdx: Int32, + ): Float32 = + val batch = localIdx / outFeatures + val outIdx = localIdx.mod(outFeatures) + val localSum = GSeq + .gen[Int32](laneId, _ + WARP_SIZE) + .limit(numVecIterations) + .unroll + .fold(0.0f, (sum: Float32, k: Int32) => + when(k < inFeaturesDiv4Val): + val wVec = GIO.read[Vec4[Float16]](weightBuffer, weightOffset + outIdx * inFeaturesDiv4Val + k) + val inputBase = batch * inFeaturesVal + k * 4 + val x0 = GIO.read[Float16](layout.input, inputBase).asFloat32 + val x1 = GIO.read[Float16](layout.input, inputBase + 1).asFloat32 + val x2 = GIO.read[Float16](layout.input, inputBase + 2).asFloat32 + val x3 = GIO.read[Float16](layout.input, inputBase + 3).asFloat32 + sum + wVec.x.asFloat32 * x0 + wVec.y.asFloat32 * x1 + wVec.z.asFloat32 * x2 + wVec.w.asFloat32 * x3 + .otherwise(sum) + ) + GIO.subgroupAdd(localSum) + + // Process Q outputs + for + _ <- GIO.when(isQ && globalOutputIdx < totalOutputsVal): + val localIdx = globalOutputIdx + val result = computeMatmul(layout.wq, wqOffsetVec4Val, qOutFeaturesVal, localIdx) + GIO.write[Float16](layout.q, localIdx, result.asFloat16) + // Process K outputs + _ <- GIO.when(isK && globalOutputIdx < totalOutputsVal): + val localIdx = globalOutputIdx - totalQOutputsVal + val result = computeMatmul(layout.wk, wkOffsetVec4Val, kvOutFeaturesVal, localIdx) + GIO.write[Float16](layout.k, localIdx, result.asFloat16) + // Process V outputs + _ <- GIO.when(isV && globalOutputIdx < totalOutputsVal): + val localIdx = globalOutputIdx - totalQOutputsVal - totalKOutputsVal + val result = computeMatmul(layout.wv, wvOffsetVec4Val, kvOutFeaturesVal, localIdx) + GIO.write[Float16](layout.v, localIdx, result.asFloat16) + yield GStruct.Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedRoPEProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedRoPEProgram.scala new file mode 100644 index 00000000..2e068281 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16FusedRoPEProgram.scala @@ -0,0 +1,124 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.llama.programs.AttentionParams + +/** Fused F16 Rotary Position Embedding for both Q and K in single dispatch. + * + * Reduces dispatch count by processing both Q and K tensors simultaneously. + * Each invocation handles one pair from either Q or K. 
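+ *
+ * Adjacent elements (2d, 2d+1) of each head form a rotation pair. The angle is
+ * freq = pos * theta^(-2d/headSize), and each pair is rotated as in the kernel body below:
+ * {{{
+ * // y0 = x0 * cos(freq) - x1 * sin(freq)
+ * // y1 = x0 * sin(freq) + x1 * cos(freq)
+ * }}}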
+ */ +object F16FusedRoPEProgram: + val BLOCK_SIZE = 256 + + case class Sizes( + B: Int, + T: Int, + numHeadsQ: Int, + numHeadsK: Int, + headSize: Int, + theta: Float, + ): + def totalQPairs: Int = B * T * numHeadsQ * (headSize / 2) + def totalKPairs: Int = B * T * numHeadsK * (headSize / 2) + def totalPairs: Int = totalQPairs + totalKPairs + + case class ProgramLayout( + qIn: GBuffer[Float16], + kIn: GBuffer[Float16], + qOut: GBuffer[Float16], + kOut: GBuffer[Float16], + params: GUniform[AttentionParams], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val B = sizes.B + val T = sizes.T + val numHeadsQ = sizes.numHeadsQ + val numHeadsK = sizes.numHeadsK + val headSize = sizes.headSize + val theta = sizes.theta + val totalQPairs = sizes.totalQPairs + val totalKPairs = sizes.totalKPairs + val totalPairs = sizes.totalPairs + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + qIn = GBuffer[Float16](s.B * s.T * s.numHeadsQ * s.headSize), + kIn = GBuffer[Float16](s.B * s.T * s.numHeadsK * s.headSize), + qOut = GBuffer[Float16](s.B * s.T * s.numHeadsQ * s.headSize), + kOut = GBuffer[Float16](s.B * s.T * s.numHeadsK * s.headSize), + params = GUniform[AttentionParams](), + ), + dispatch = (_, s) => StaticDispatch(((s.totalPairs + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val idx = GIO.invocationId + val totalPairsVal: Int32 = totalPairs + val totalQPairsVal: Int32 = totalQPairs + val Tval: Int32 = T + val numHeadsQVal: Int32 = numHeadsQ + val numHeadsKVal: Int32 = numHeadsK + val halfHead: Int32 = headSize / 2 + val headSizeVal: Int32 = headSize + val thetaVal: Float32 = theta + val startPosVal: Int32 = layout.params.read.startPos + + GIO.when(idx < totalPairsVal): + // Determine if this is Q or K based on index + val isQ = idx < totalQPairsVal + + // Calculate local index within Q or K + val localIdx: Int32 = when(isQ)(idx).otherwise(idx - totalQPairsVal) + val numHeads: Int32 = when(isQ)(numHeadsQVal).otherwise(numHeadsKVal) + + // Decompose index + val perHead = halfHead + val perPos = numHeads * halfHead + val perBatch = Tval * perPos + + val b = localIdx / perBatch + val rem1 = localIdx.mod(perBatch) + val t = rem1 / perPos + val rem2 = rem1.mod(perPos) + val h = rem2 / perHead + val d = rem2.mod(perHead) + + // Compute RoPE rotation + val pos = startPosVal + t + val headSizeFloat: Float32 = headSize.toFloat + val freqExponent: Float32 = -2.0f * d.asFloat / headSizeFloat + val freq: Float32 = pos.asFloat * pow(thetaVal, freqExponent) + val cosFreq = cos(freq).asFloat16 + val sinFreq = sin(freq).asFloat16 + + // Calculate full indices for reading/writing + val fullIdx: Int32 = b * Tval * numHeads * headSizeVal + t * numHeads * headSizeVal + h * headSizeVal + val idx0 = fullIdx + d * 2 + val idx1 = idx0 + 1 + + // Read, rotate, write - branch on Q vs K + for + _ <- GIO.when(isQ): + val x0 = GIO.read[Float16](layout.qIn, idx0) + val x1 = GIO.read[Float16](layout.qIn, idx1) + val y0 = x0 * cosFreq - x1 * sinFreq + val y1 = x0 * sinFreq + x1 * cosFreq + for + _ <- GIO.write[Float16](layout.qOut, idx0, y0) + _ <- GIO.write[Float16](layout.qOut, idx1, y1) + yield GStruct.Empty() + _ <- GIO.when(!isQ): + val x0 = GIO.read[Float16](layout.kIn, idx0) + val x1 = GIO.read[Float16](layout.kIn, idx1) + val y0 = x0 * cosFreq - x1 * sinFreq + val y1 = x0 * sinFreq + x1 * cosFreq + for + _ <- GIO.write[Float16](layout.kOut, idx0, y0) + _ <- GIO.write[Float16](layout.kOut, idx1, y1) + yield GStruct.Empty() + 
yield GStruct.Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16MatmulVecHybridProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16MatmulVecHybridProgram.scala new file mode 100644 index 00000000..ce773650 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16MatmulVecHybridProgram.scala @@ -0,0 +1,215 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.binding.{GBuffer, GShared} +import io.computenode.cyfra.dsl.gio.GIO + +/** F16 matrix-vector multiply with Vec4-packed weights. + * + * Uses Vec4[Float16] weights for 4x memory bandwidth while keeping scalar input. + * Optimal for activation-weight multiplies where weights are static but activations vary. + * + * Optimized to compute 2 output rows in PARALLEL within the same loop iteration, + * reading input once and reusing for both rows (like llama.cpp). + * + * Uses 128 threads (4 warps) with shared memory reduction for better occupancy. + * + * @note Requires `inFeatures` divisible by 4 for Vec4 alignment. + */ +object F16MatmulVecHybridProgram: + val WARP_SIZE = 32 + val NUM_WARPS = 4 + val NUM_ROWS = 2 // Each workgroup computes 2 output rows in parallel + val BLOCK_SIZE = WARP_SIZE * NUM_WARPS // 128 threads + + case class Sizes( + batchSize: Int, + inFeatures: Int, + outFeatures: Int, + weightOffsetVec4: Int = 0, + totalWeightVec4: Int = -1, + ): + require(inFeatures % 4 == 0, s"inFeatures ($inFeatures) must be divisible by 4") + def inFeaturesDiv4: Int = inFeatures / 4 + def totalOutputs: Int = batchSize * outFeatures + def numWorkgroups: Int = (totalOutputs + NUM_ROWS - 1) / NUM_ROWS + // Exact iterations - no bounds check needed if inFeatures is multiple of (BLOCK_SIZE * 4) + def numVecIterations: Int = inFeaturesDiv4 / BLOCK_SIZE + def actualWeightVec4: Int = if totalWeightVec4 < 0 then outFeatures * inFeaturesDiv4 else totalWeightVec4 + + case class ProgramLayout( + weight: GBuffer[Vec4[Float16]], + input: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + /** Layout with Vec4 input for optimal memory bandwidth. 
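+ * Used by forwardVec4 below: the scalar-input variant issues four Float16 reads per
+ * weight Vec4, while this layout fetches the same activations in a single Vec4 read.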
*/ + case class ProgramLayoutVec4( + weight: GBuffer[Vec4[Float16]], + input: GBuffer[Vec4[Float16]], + output: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + // Compile-time constants (embedded in shader as literals) + val inFeatures = sizes.inFeatures + val inFeaturesDiv4 = sizes.inFeaturesDiv4 + val outFeatures = sizes.outFeatures + val weightOffsetVec4 = sizes.weightOffsetVec4 + val numVecIterations = sizes.numVecIterations + val totalOutputs = sizes.totalOutputs + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + weight = GBuffer[Vec4[Float16]](s.actualWeightVec4), + input = GBuffer[Float16](s.batchSize * s.inFeatures), + output = GBuffer[Float16](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId: Int32 = tid.mod(WARP_SIZE) + val warpId: Int32 = tid / WARP_SIZE + + // Shared memory for partial sums from each warp + val sharedSums = GShared[Vec2[Float32]](NUM_WARPS) + + // Use for-comprehension with GIO.pure to force index computations BEFORE the loop + for + // Pre-compute all indices outside the hot loop + firstRow <- GIO.pure(workgroupId * NUM_ROWS) + batch <- GIO.pure(firstRow / outFeatures) + inputBase <- GIO.pure(batch * inFeatures) + outIdx0 <- GIO.pure(firstRow - batch * outFeatures) + outIdx1 <- GIO.pure(when(outIdx0 + 1 < outFeatures)(outIdx0 + 1).otherwise(0)) + weightOffset <- GIO.pure(weightOffsetVec4: Int32) + inFeatDiv4 <- GIO.pure(inFeaturesDiv4: Int32) + wBase0 <- GIO.pure(weightOffset + outIdx0 * inFeatDiv4) + wBase1 <- GIO.pure(weightOffset + outIdx1 * inFeatDiv4) + + // Hot loop - NO BOUNDS CHECK (assume inFeatures is multiple of BLOCK_SIZE*4) + // This matches llama.cpp's fast path pattern + localSums <- GIO.pure(GSeq + .gen[Int32](tid, _ + BLOCK_SIZE) + .limit(numVecIterations) + .unroll + .fold(vec2(0.0f, 0.0f), (acc: Vec2[Float32], k: Int32) => + // Read 4 input values and construct Vec4 + val iBase = inputBase + k * 4 + val xVec = vec4( + GIO.read[Float16](layout.input, iBase).asFloat32, + GIO.read[Float16](layout.input, iBase + 1).asFloat32, + GIO.read[Float16](layout.input, iBase + 2).asFloat32, + GIO.read[Float16](layout.input, iBase + 3).asFloat32, + ) + + // Read weights for both rows + val wVec0 = GIO.read[Vec4[Float16]](layout.weight, wBase0 + k) + val wVec1 = GIO.read[Vec4[Float16]](layout.weight, wBase1 + k) + + // Compute both dot products using vec4 operations + vec2(acc.x + wVec0.asVec4F32.dot(xVec), acc.y + wVec1.asVec4F32.dot(xVec)) + )) + + // First reduce within each warp using subgroupAdd + warpSum <- GIO.pure(vec2(GIO.subgroupAdd(localSums.x), GIO.subgroupAdd(localSums.y))) + + // Lane 0 of each warp writes to shared memory + _ <- GIO.when(laneId === 0): + sharedSums.write(warpId, warpSum) + _ <- GIO.barrier + + // All threads read shared memory (but only tid=0 will write) + sum0 <- GIO.pure(sharedSums.read(0)) + sum1 <- GIO.pure(sharedSums.read(1)) + sum2 <- GIO.pure(sharedSums.read(2)) + sum3 <- GIO.pure(sharedSums.read(3)) + total0 <- GIO.pure(sum0.x + sum1.x + sum2.x + sum3.x) + total1 <- GIO.pure(sum0.y + sum1.y + sum2.y + sum3.y) + + // Thread 0 writes output + _ <- GIO.when(tid === 0): + GIO.write[Float16](layout.output, firstRow, total0.asFloat16) + _ <- GIO.when(tid === 0 && firstRow + 1 < totalOutputs): + GIO.write[Float16](layout.output, firstRow + 1, total1.asFloat16) + 
yield GStruct.Empty() + + /** Optimized forward with Vec4 input reads - 4x fewer memory transactions. */ + def forwardVec4(sizes: Sizes): GProgram[Sizes, ProgramLayoutVec4] = + val inFeaturesDiv4 = sizes.inFeaturesDiv4 + val outFeatures = sizes.outFeatures + val weightOffsetVec4 = sizes.weightOffsetVec4 + val numVecIterations = sizes.numVecIterations + val totalOutputs = sizes.totalOutputs + + GProgram[Sizes, ProgramLayoutVec4]( + layout = s => ProgramLayoutVec4( + weight = GBuffer[Vec4[Float16]](s.actualWeightVec4), + input = GBuffer[Vec4[Float16]](s.batchSize * s.inFeaturesDiv4), + output = GBuffer[Float16](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val laneId: Int32 = tid.mod(WARP_SIZE) + val warpId: Int32 = tid / WARP_SIZE + + val sharedSums = GShared[Vec2[Float32]](NUM_WARPS) + + for + // Pre-compute all indices outside the hot loop + firstRow <- GIO.pure(workgroupId * NUM_ROWS) + batch <- GIO.pure(firstRow / outFeatures) + inputBase <- GIO.pure(batch * inFeaturesDiv4) + outIdx0 <- GIO.pure(firstRow - batch * outFeatures) + outIdx1 <- GIO.pure(when(outIdx0 + 1 < outFeatures)(outIdx0 + 1).otherwise(0)) + weightOffset <- GIO.pure(weightOffsetVec4: Int32) + inFeatDiv4 <- GIO.pure(inFeaturesDiv4: Int32) + wBase0 <- GIO.pure(weightOffset + outIdx0 * inFeatDiv4) + wBase1 <- GIO.pure(weightOffset + outIdx1 * inFeatDiv4) + + // Hot loop - NO BOUNDS CHECK + localSums <- GIO.pure(GSeq + .gen[Int32](tid, _ + BLOCK_SIZE) + .limit(numVecIterations) + .unroll + .fold(vec2(0.0f, 0.0f), (acc: Vec2[Float32], k: Int32) => + // Read input ONCE as Vec4 + val xVec = GIO.read[Vec4[Float16]](layout.input, inputBase + k).asVec4F32 + + // Read weights for both rows + val wVec0 = GIO.read[Vec4[Float16]](layout.weight, wBase0 + k).asVec4F32 + val wVec1 = GIO.read[Vec4[Float16]](layout.weight, wBase1 + k).asVec4F32 + + vec2(acc.x + wVec0.dot(xVec), acc.y + wVec1.dot(xVec)) + )) + + // First reduce within each warp + warpSum <- GIO.pure(vec2(GIO.subgroupAdd(localSums.x), GIO.subgroupAdd(localSums.y))) + + // Lane 0 of each warp writes to shared memory + _ <- GIO.when(laneId === 0): + sharedSums.write(warpId, warpSum) + _ <- GIO.barrier + + // All threads read shared memory (but only tid=0 will write) + sum0 <- GIO.pure(sharedSums.read(0)) + sum1 <- GIO.pure(sharedSums.read(1)) + sum2 <- GIO.pure(sharedSums.read(2)) + sum3 <- GIO.pure(sharedSums.read(3)) + total0 <- GIO.pure(sum0.x + sum1.x + sum2.x + sum3.x) + total1 <- GIO.pure(sum0.y + sum1.y + sum2.y + sum3.y) + + // Thread 0 writes output + _ <- GIO.when(tid === 0): + GIO.write[Float16](layout.output, firstRow, total0.asFloat16) + _ <- GIO.when(tid === 0 && firstRow + 1 < totalOutputs): + GIO.write[Float16](layout.output, firstRow + 1, total1.asFloat16) + yield GStruct.Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16OutputVec4Program.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16OutputVec4Program.scala new file mode 100644 index 00000000..29c5b5a0 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16OutputVec4Program.scala @@ -0,0 +1,104 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import 
io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.GStruct + +/** F16 output projection with Vec4-packed weights. + * + * Projects hidden states to vocabulary logits using Vec4[Float16] weights. + * Output is F32 for softmax numerical stability. + * + * Optimized with NUM_ROWS=4: each workgroup computes 4 output rows. + * + * @note Requires `hiddenSize` divisible by 4 for Vec4 alignment. + */ +object F16OutputVec4Program: + val WARP_SIZE = 32 + val NUM_ROWS = 4 // Each workgroup computes 4 rows + val BLOCK_SIZE = WARP_SIZE // Single warp per workgroup + + case class Sizes(batchSize: Int, hiddenSize: Int, vocabSize: Int): + require(hiddenSize % 4 == 0, s"hiddenSize ($hiddenSize) must be divisible by 4") + def hiddenSizeDiv4: Int = hiddenSize / 4 + def totalOutputs: Int = batchSize * vocabSize + def numWorkgroups: Int = (totalOutputs + NUM_ROWS - 1) / NUM_ROWS + def numVecIterations: Int = (hiddenSizeDiv4 + WARP_SIZE - 1) / WARP_SIZE + + case class ProgramLayout( + input: GBuffer[Float16], + weight: GBuffer[Vec4[Float16]], + output: GBuffer[Float32], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val hiddenSize = sizes.hiddenSize + val hiddenSizeDiv4 = sizes.hiddenSizeDiv4 + val vocabSize = sizes.vocabSize + val numVecIterations = sizes.numVecIterations + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + input = GBuffer[Float16](s.batchSize * s.hiddenSize), + weight = GBuffer[Vec4[Float16]](s.vocabSize * s.hiddenSizeDiv4), + output = GBuffer[Float32](s.totalOutputs), + ), + dispatch = (_, s) => StaticDispatch((s.numWorkgroups, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val workgroupId: Int32 = GIO.workgroupId.x + val hiddenSizeVal: Int32 = hiddenSize + val hiddenSizeDiv4Val: Int32 = hiddenSizeDiv4 + val vocabSizeVal: Int32 = vocabSize + val totalOutputsVal: Int32 = sizes.totalOutputs + + // Each workgroup computes NUM_ROWS consecutive output rows + val firstRow = workgroupId * NUM_ROWS + val batch = firstRow / vocabSizeVal + val inputBase0 = batch * hiddenSizeVal + + // Helper to compute one row's dot product + def computeRow(vocabIdx: Int32): Float32 = + val localSum = GSeq + .gen[Int32](tid, _ + WARP_SIZE) + .limit(numVecIterations) + .unroll + .fold(0.0f, (sum: Float32, k: Int32) => + when(k < hiddenSizeDiv4Val): + val wVec = GIO.read[Vec4[Float16]](layout.weight, vocabIdx * hiddenSizeDiv4Val + k) + val inputBase = inputBase0 + k * 4 + val x0 = GIO.read[Float16](layout.input, inputBase).asFloat32 + val x1 = GIO.read[Float16](layout.input, inputBase + 1).asFloat32 + val x2 = GIO.read[Float16](layout.input, inputBase + 2).asFloat32 + val x3 = GIO.read[Float16](layout.input, inputBase + 3).asFloat32 + sum + wVec.x.asFloat32 * x0 + wVec.y.asFloat32 * x1 + wVec.z.asFloat32 * x2 + wVec.w.asFloat32 * x3 + .otherwise(sum) + ) + GIO.subgroupAdd(localSum) + + // Compute all 4 rows + val vocabIdx0 = firstRow.mod(vocabSizeVal) + val vocabIdx1 = (firstRow + 1).mod(vocabSizeVal) + val vocabIdx2 = (firstRow + 2).mod(vocabSizeVal) + val vocabIdx3 = (firstRow + 3).mod(vocabSizeVal) + + val sum0 = computeRow(vocabIdx0) + val sum1 = computeRow(vocabIdx1) + val sum2 = computeRow(vocabIdx2) + val sum3 = computeRow(vocabIdx3) + + // Write results + for + _ <- GIO.when(firstRow < totalOutputsVal): + GIO.write[Float32](layout.output, firstRow, sum0) + _ <- GIO.when(firstRow + 1 < totalOutputsVal): + 
GIO.write[Float32](layout.output, firstRow + 1, sum1) + _ <- GIO.when(firstRow + 2 < totalOutputsVal): + GIO.write[Float32](layout.output, firstRow + 2, sum2) + _ <- GIO.when(firstRow + 3 < totalOutputsVal): + GIO.write[Float32](layout.output, firstRow + 3, sum3) + yield GStruct.Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16RMSNormProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16RMSNormProgram.scala new file mode 100644 index 00000000..ea6ccdb2 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16RMSNormProgram.scala @@ -0,0 +1,108 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.binding.GShared + +/** F16 Root Mean Square Layer Normalization. + * + * Normalizes input by RMS: `output[i] = input[i] / rms(input) * weight[i]`. + * Accumulates in F32 for numerical precision. + */ +object F16RMSNormProgram: + val BLOCK_SIZE = 512 + + case class Sizes( + numRows: Int, + rowSize: Int, + eps: Float, + weightOffset: Int = 0, + totalWeightSize: Int = -1, + ): + def numIterations: Int = (rowSize + BLOCK_SIZE - 1) / BLOCK_SIZE + def actualWeightSize: Int = if totalWeightSize < 0 then rowSize else totalWeightSize + + case class ProgramLayout( + input: GBuffer[Float16], + weight: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val rowSize = sizes.rowSize + val eps = sizes.eps + val numIterations = sizes.numIterations + val weightOffset = sizes.weightOffset + + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + input = GBuffer[Float16](s.numRows * s.rowSize), + weight = GBuffer[Float16](s.actualWeightSize), + output = GBuffer[Float16](s.numRows * s.rowSize), + ), + dispatch = (_, s) => StaticDispatch((s.numRows, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid = GIO.localInvocationIndex + val row = GIO.workgroupId.x + val shared = GShared[Float32](BLOCK_SIZE) + + val rowSizeVal: Int32 = rowSize + val epsVal: Float32 = eps + val weightOffsetVal: Int32 = weightOffset + val baseIdx = row * rowSizeVal + + // Phase 1: Each thread sums its strided elements (pure DSL expression) + val localSum = GSeq + .gen[Int32](tid, _ + BLOCK_SIZE) + .limit(numIterations) + .fold(0.0f, (sum: Float32, col: Int32) => + when(col < rowSizeVal): + val x = GIO.read[Float16](layout.input, baseIdx + col).asFloat32 + sum + (x * x) + .otherwise(sum) + ) + + for + // Write local sum to shared memory + _ <- shared.write(tid, localSum) + _ <- GIO.barrier + + // Tree reduction: 9 levels for 512 threads + _ <- reduceShared(shared, tid, 256) + _ <- reduceShared(shared, tid, 128) + _ <- reduceShared(shared, tid, 64) + _ <- reduceShared(shared, tid, 32) + _ <- reduceShared(shared, tid, 16) + _ <- reduceShared(shared, tid, 8) + _ <- reduceShared(shared, tid, 4) + _ <- reduceShared(shared, tid, 2) + _ <- reduceShared(shared, tid, 1) + + // Phase 3: Compute scale and write output + _ <- { + val totalSum = shared.read(0) + val scale: Float32 = 1.0f / sqrt((totalSum / rowSizeVal.asFloat) + epsVal) + + GIO.repeat(numIterations): iter => + val col = tid + iter * BLOCK_SIZE + GIO.when(col < rowSizeVal): + val x = GIO.read[Float16](layout.input, 
baseIdx + col).asFloat32 + val w = GIO.read[Float16](layout.weight, weightOffsetVal + col).asFloat32 + GIO.write[Float16](layout.output, baseIdx + col, (x * scale * w).asFloat16) + } + yield GStruct.Empty() + + // Helper for one level of tree reduction + private def reduceShared(shared: GShared[Float32], tid: Int32, stride: Int): GIO[GStruct.Empty] = + val strideVal: Int32 = stride + for + _ <- GIO.when(tid < strideVal): + val current = shared.read(tid) + val other = shared.read(tid + strideVal) + shared.write(tid, current + other) + _ <- GIO.barrier + yield GStruct.Empty() \ No newline at end of file diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16ResidualAddProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16ResidualAddProgram.scala new file mode 100644 index 00000000..182c4948 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16ResidualAddProgram.scala @@ -0,0 +1,33 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO + +/** F16 element-wise addition for residual connections: `output = a + b`. */ +object F16ResidualAddProgram: + case class Sizes(size: Int) + + case class ProgramLayout( + a: GBuffer[Float16], + b: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + a = GBuffer[Float16](s.size), + b = GBuffer[Float16](s.size), + output = GBuffer[Float16](s.size), + ), + dispatch = (_, s) => StaticDispatch(((s.size + 255) / 256, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val idx = GIO.invocationId + GIO.when(idx < sizes.size): + val aVal = GIO.read[Float16](layout.a, idx) + val bVal = GIO.read[Float16](layout.b, idx) + GIO.write[Float16](layout.output, idx, aVal + bVal) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16SwiGLUProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16SwiGLUProgram.scala new file mode 100644 index 00000000..b7499baf --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16SwiGLUProgram.scala @@ -0,0 +1,39 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.gio.GIO + +/** F16 SwiGLU activation: `SiLU(gate) * up`. + * + * Combines gated linear unit with SiLU (swish) activation. + * Computes in F32 internally for precision. 
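+ *
+ * Illustrative formula (mirrors the kernel body):
+ * {{{
+ * // SiLU(g) = g * sigmoid(g) = g / (1 + exp(-g))
+ * // output[i] = SiLU(gate[i]) * up[i]
+ * }}}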
+ */ +object F16SwiGLUProgram: + case class Sizes(numElements: Int) + + case class ProgramLayout( + gate: GBuffer[Float16], + up: GBuffer[Float16], + output: GBuffer[Float16], + ) derives Layout + + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + GProgram[Sizes, ProgramLayout]( + layout = s => ProgramLayout( + gate = GBuffer[Float16](s.numElements), + up = GBuffer[Float16](s.numElements), + output = GBuffer[Float16](s.numElements), + ), + dispatch = (_, s) => StaticDispatch(((s.numElements + 255) / 256, 1, 1)), + workgroupSize = (256, 1, 1), + ): layout => + val tid = GIO.invocationId + GIO.when(tid < sizes.numElements): + val g = GIO.read[Float16](layout.gate, tid).asFloat32 + val u = GIO.read[Float16](layout.up, tid).asFloat32 + val sigmoidG = 1.0f / (1.0f + exp(-g)) + val result = g * sigmoidG * u + GIO.write[Float16](layout.output, tid, result.asFloat16) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16TopPSampleProgram.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16TopPSampleProgram.scala new file mode 100644 index 00000000..0b69a882 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/F16TopPSampleProgram.scala @@ -0,0 +1,223 @@ +package io.computenode.cyfra.llama.programs.f16 + +import io.computenode.cyfra.core.GProgram +import io.computenode.cyfra.core.GProgram.StaticDispatch +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.dsl.binding.GShared +import io.computenode.cyfra.dsl.gio.GIO +import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} + +/** GPU-based Top-P (nucleus) sampling. + * + * Uses a multi-phase approach: + * 1. Softmax: Find max, compute exp values, sum (subgroup reductions) + * 2. Local Top: Each thread finds its best candidate + * 3. Warp reduction: Each warp finds its best candidate (8 total) + * 4. Sort: Simple bubble sort on 8 candidates + * 5. Sample: Cumulative sum over sorted candidates, sample at threshold + * + * For peaked LLM distributions, 8 warp-best candidates typically cover 90%+ probability. + * Falls back to argmax if sampling fails. + * + * Configuration: + * - BLOCK_SIZE = 256 threads (8 warps) + * - NUM_WARPS = 8 candidates for sampling + */ +object F16TopPSampleProgram: + val WARP_SIZE = 32 + val BLOCK_SIZE = 256 + val NUM_WARPS = BLOCK_SIZE / WARP_SIZE // 8 + + case class Sizes(vocabSize: Int): + def numIterations: Int = (vocabSize + BLOCK_SIZE - 1) / BLOCK_SIZE + + case class SampleParams( + temperature: Float32, + topP: Float32, + randomValue: Float32, // Pre-generated random [0, 1) + ) extends GStruct[SampleParams] + + object SampleParams: + given GStructSchema[SampleParams] = GStructSchema.derived + + case class ProgramLayout( + logits: GBuffer[Float32], + params: GUniform[SampleParams], + result: GBuffer[Int32], // Output: single sampled token index + ) derives Layout + + /** Top-p sampling with temperature scaling. 
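+ * Temperature rescales logits as (logit - max) / temperature before the softmax;
+ * temperatures at or below 1e-4 are treated as 1.0 (no scaling). The host supplies
+ * randomValue in [0, 1) so a dispatch is deterministic for a given parameter buffer.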
+ * + * @param sizes Contains vocabulary size + * @return GProgram that samples a token index from logits + */ + def forward(sizes: Sizes): GProgram[Sizes, ProgramLayout] = + val vocabSize = sizes.vocabSize + val numIterations = sizes.numIterations + + GProgram[Sizes, ProgramLayout]( + layout = _ => ProgramLayout( + logits = GBuffer[Float32](vocabSize), + params = GUniform[SampleParams](), + result = GBuffer[Int32](1), + ), + dispatch = (_, _) => StaticDispatch((1, 1, 1)), + workgroupSize = (BLOCK_SIZE, 1, 1), + ): layout => + val tid: Int32 = GIO.localInvocationId.x + val laneId: Int32 = tid.mod(WARP_SIZE) + val warpId: Int32 = tid / WARP_SIZE + + val runtimeParams = layout.params.read + val temperature: Float32 = runtimeParams.temperature + val topP: Float32 = runtimeParams.topP + val randomValue: Float32 = runtimeParams.randomValue + + // Shared memory for reductions + val sharedMax = GShared[Float32](NUM_WARPS) + val sharedSum = GShared[Float32](NUM_WARPS) + // Store warp-best candidates: 8 probs and 8 indices + val sharedProbs = GShared[Float32](NUM_WARPS) + val sharedIndices = GShared[Int32](NUM_WARPS) + + for + // ========== PHASE 1: SOFTMAX ========== + // Step 1a: Each thread finds local max + localMax <- GIO.pure { + GSeq.gen[Int32](tid, _ + BLOCK_SIZE).limit(numIterations).fold(-1e10f, (maxVal: Float32, idx: Int32) => + when(idx < vocabSize)( + max(maxVal, GIO.read[Float32](layout.logits, idx)) + ).otherwise(maxVal) + ) + } + + // Step 1b: Warp reduction for max + warpMax <- GIO.pure(GIO.subgroupMax(localMax)) + _ <- GIO.when(laneId === 0): + sharedMax.write(warpId, warpMax) + _ <- GIO.barrier + + // Step 1c: Final max reduction (first warp reads all, reduces) + globalMax <- GIO.pure { + when(tid < NUM_WARPS)(sharedMax.read(tid)).otherwise(-1e10f) + } + globalMaxWarp0 <- GIO.pure(GIO.subgroupMax(globalMax)) + // Broadcast global max to all threads via shared memory + _ <- GIO.when(tid === 0): + sharedMax.write(0, globalMaxWarp0) + _ <- GIO.barrier + globalMaxReduced <- GIO.pure(sharedMax.read(0)) + + // Step 1d: Temperature scaling factor + invTemp <- GIO.pure(when(temperature > 0.0001f)(1.0f / temperature).otherwise(1.0f)) + + // Step 1e: Each thread computes local exp sum + localSum <- GIO.pure { + GSeq.gen[Int32](tid, _ + BLOCK_SIZE).limit(numIterations).fold(0.0f, (sumVal: Float32, idx: Int32) => + when(idx < vocabSize) { + val logit: Float32 = GIO.read[Float32](layout.logits, idx) + val expVal: Float32 = exp((logit - globalMaxReduced) * invTemp) + sumVal + expVal + }.otherwise(sumVal) + ) + } + + // Step 1f: Warp reduction for sum + warpSum <- GIO.pure(GIO.subgroupAdd(localSum)) + _ <- GIO.when(laneId === 0): + sharedSum.write(warpId, warpSum) + _ <- GIO.barrier + + // Step 1g: Final sum reduction + globalSum <- GIO.pure { + when(tid < NUM_WARPS)(sharedSum.read(tid)).otherwise(0.0f) + } + globalSumWarp0 <- GIO.pure(GIO.subgroupAdd(globalSum) + 1e-10f) + // Broadcast global sum to all threads via shared memory + _ <- GIO.when(tid === 0): + sharedSum.write(0, globalSumWarp0) + _ <- GIO.barrier + globalSumReduced <- GIO.pure(sharedSum.read(0)) + rcpSum <- GIO.pure(1.0f / globalSumReduced) + + // ========== PHASE 2: LOCAL TOP CANDIDATE SELECTION ========== + // Each thread finds its best candidate (prob, idx) stored as Vec2 + localTop <- GIO.pure { + GSeq.gen[Int32](tid, _ + BLOCK_SIZE).limit(numIterations).fold( + vec2(-1e10f, -1.0f), // (maxProb, maxIdx) + (state: Vec2[Float32], idx: Int32) => + when(idx < vocabSize) { + val logit: Float32 = GIO.read[Float32](layout.logits, 
idx) + val prob: Float32 = exp((logit - globalMaxReduced) * invTemp) * rcpSum + when(prob > state.x)(vec2(prob, idx.asFloat)).otherwise(state) + }.otherwise(state) + ) + } + + // ========== PHASE 3: WARP-LEVEL REDUCTION TO GET TOP 8 ========== + // Each warp finds its best candidate using subgroup reductions + warpBestProb <- GIO.pure(GIO.subgroupMax(localTop.x)) + + // Tolerance-based comparison to handle floating-point precision + isWarpBest <- GIO.pure(localTop.x >= warpBestProb - 1e-7f) + + // For threads with best prob, use their index; others use large value + // subgroupMin will select the lowest index among ties + warpBestIdx <- GIO.pure(when(isWarpBest)(localTop.y).otherwise(1e10f)) + warpWinnerIdx <- GIO.pure(GIO.subgroupMin(warpBestIdx)) + + // Lane 0 of each warp writes the result + _ <- GIO.when(laneId === 0): + for + _ <- sharedProbs.write(warpId, warpBestProb) + _ <- sharedIndices.write(warpId, warpWinnerIdx.asInt) + yield GStruct.Empty() + _ <- GIO.barrier + + // ========== PHASE 4: SORT 8 CANDIDATES (single thread) ========== + // Simple bubble sort for 8 elements - very fast + _ <- GIO.when(tid === 0) { + GIO.repeat(NUM_WARPS): _ => + GIO.repeat(NUM_WARPS - 1): j => + val jIdx: Int32 = j + val jIdxPlus1: Int32 = jIdx + 1 + // Read all values BEFORE any writes to avoid the swap bug + for + prob0 <- GIO.pure(sharedProbs.read(jIdx)) + prob1 <- GIO.pure(sharedProbs.read(jIdxPlus1)) + idx0 <- GIO.pure(sharedIndices.read(jIdx)) + idx1 <- GIO.pure(sharedIndices.read(jIdxPlus1)) + // Swap if out of order (descending) + _ <- GIO.when(prob0 < prob1): + for + _ <- sharedProbs.write(jIdx, prob1) + _ <- sharedProbs.write(jIdxPlus1, prob0) + _ <- sharedIndices.write(jIdx, idx1) + _ <- sharedIndices.write(jIdxPlus1, idx0) + yield GStruct.Empty() + yield GStruct.Empty() + } + _ <- GIO.barrier + + // ========== PHASE 5: CUMULATIVE SUM + SAMPLE ========== + _ <- GIO.when(tid === 0) { + // Compute cumulative sum and sample + val threshold = randomValue * topP + val result = GSeq.gen[Int32](0, _ + 1).limit(NUM_WARPS).fold( + vec2(0.0f, -1.0f), // (cumSum, sampledIdx) + (state: Vec2[Float32], i: Int32) => + val cumSum = state.x + val sampledIdx = state.y + val prob = sharedProbs.read(i) + val idx = sharedIndices.read(i) + val newCumSum = cumSum + prob + // Only set sampledIdx if not already set AND we've exceeded threshold + val newSampledIdx = when(sampledIdx < 0.0f && newCumSum >= threshold)(idx.asFloat).otherwise(sampledIdx) + vec2(newCumSum, newSampledIdx) + ) + // If sampling didn't find a result, use argmax (position 0 after sort) + val finalIdx = when(result.y >= 0.0f)(result.y.asInt).otherwise(sharedIndices.read(0)) + GIO.write[Int32](layout.result, 0, finalIdx) + } + yield GStruct.Empty() diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/package.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/package.scala new file mode 100644 index 00000000..893c40f4 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/f16/package.scala @@ -0,0 +1,4 @@ +package io.computenode.cyfra.llama.programs + +/** F16 (Float16) precision programs for Llama inference */ +package object f16 diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/package.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/package.scala new file mode 100644 index 00000000..da5e9c56 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/programs/package.scala @@ -0,0 +1,16 @@ +package 
io.computenode.cyfra.llama + +import io.computenode.cyfra.dsl.Value.Int32 +import io.computenode.cyfra.dsl.struct.GStruct + +package object programs: + + /** Runtime parameters for attention and RoPE operations. + * + * Passed via uniform to support single compiled pipeline with runtime-varying positions. + * Used by RoPE (for position encoding) and attention (for KV cache operations). + */ + case class AttentionParams( + seqLen: Int32, // actual sequence length (startPos + T) + startPos: Int32, // position of first query token in full sequence + ) extends GStruct[AttentionParams] diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/tokenizer/LlamaTokenizer.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/tokenizer/LlamaTokenizer.scala new file mode 100644 index 00000000..3f6fbab6 --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/tokenizer/LlamaTokenizer.scala @@ -0,0 +1,121 @@ +package io.computenode.cyfra.llama.tokenizer + +import io.computenode.cyfra.llama.gguf.GGUFReader.GGUFFile +import scala.collection.mutable + +/** Simple BPE tokenizer for Llama models. + * + * Reads vocabulary from GGUF metadata. + * Supports encoding (text -> tokens) and decoding (tokens -> text). + */ +class LlamaTokenizer(gguf: GGUFFile): + + // Special tokens + val bosToken: Int = gguf.getInt("tokenizer.ggml.bos_token_id").getOrElse(1) + val eosToken: Int = gguf.getInt("tokenizer.ggml.eos_token_id").getOrElse(2) + val padToken: Int = gguf.getInt("tokenizer.ggml.padding_token_id").getOrElse(0) + + // Load vocabulary from GGUF + private val vocab: Array[String] = gguf.getStringArray("tokenizer.ggml.tokens").getOrElse(Array.empty) + private val scores: Array[Float] = gguf.getFloatArray("tokenizer.ggml.scores").getOrElse(Array.empty) + + // Build reverse lookup for encoding + private val tokenToId: Map[String, Int] = vocab.zipWithIndex.toMap + + /** Number of tokens in vocabulary. */ + def vocabSize: Int = vocab.length + + // GPT-style special characters used by Llama 3 + private val GPT_SPACE = '\u0120' // Ġ - space marker + private val GPT_NEWLINE = '\u010A' // Ċ - newline marker + private val GPT_TAB = '\u0109' // ĉ - tab marker + + /** Decode a single token to string. + * Handles special byte tokens like <0xNN> and GPT-style markers. + */ + def decodeToken(tokenId: Int): String = + if tokenId < 0 || tokenId >= vocab.length then + s"" + else + val token = vocab(tokenId) + // Handle byte tokens like <0xNN> + if token.startsWith("<0x") && token.endsWith(">") then + try + val byteVal = Integer.parseInt(token.drop(3).dropRight(1), 16) + new String(Array(byteVal.toByte), "UTF-8") + catch + case _: Exception => token + else + // Replace GPT-style markers with actual characters + token + .replace(GPT_SPACE.toString, " ") + .replace(GPT_NEWLINE.toString, "\n") + .replace(GPT_TAB.toString, "\t") + .replace("▁", " ") // Sentencepiece space marker + + /** Decode a sequence of tokens to string. */ + def decode(tokens: Array[Int]): String = + tokens.map(decodeToken).mkString + + // Detect which space marker the vocabulary uses (▁ for Llama 1/2, Ġ for Llama 3) + private val spaceMarker: String = + if tokenToId.contains("Ġ") || vocab.exists(_.startsWith("Ġ")) then "Ġ" + else "▁" + + /** Encode text to tokens using greedy longest-match BPE. + * + * This handles both SentencePiece (▁) and GPT (Ġ) space marker conventions: + * - Space marker represents a space before the token + * - First token of a word has space marker prefix + * + * Note: This is a simplified implementation. 
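+ * Greedy longest-match can split rare strings differently from merge-order BPE,
+ * so token counts may differ slightly from the reference tokenizer.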
For production, + * use the official sentencepiece tokenizer. + */ + def encode(text: String, addBos: Boolean = true): Array[Int] = + val tokens = mutable.ArrayBuffer[Int]() + + if addBos then + tokens += bosToken + + // Replace spaces with detected space marker and prepend for start + val normalized = spaceMarker + text.replace(" ", spaceMarker) + + // Greedy longest-match tokenization + var pos = 0 + while pos < normalized.length do + var found = false + var maxLen = math.min(normalized.length - pos, 64) // Max token length + + // Try to find longest matching token + while maxLen > 0 && !found do + val candidate = normalized.substring(pos, pos + maxLen) + if tokenToId.contains(candidate) then + tokens += tokenToId(candidate) + pos += maxLen + found = true + else + maxLen -= 1 + + if !found then + // Single character fallback + val char = normalized.charAt(pos) + val charStr = char.toString + if tokenToId.contains(charStr) then + tokens += tokenToId(charStr) + else + // Unknown character - try byte fallback for UTF-8 bytes + val bytes = charStr.getBytes("UTF-8") + for b <- bytes do + val byteToken = f"<0x${b & 0xFF}%02X>" + tokenToId.get(byteToken).foreach(tokens += _) + pos += 1 + + tokens.toArray + + /** Get token string by ID (for debugging). */ + def getToken(tokenId: Int): String = + if tokenId >= 0 && tokenId < vocab.length then vocab(tokenId) + else s"" + +object LlamaTokenizer: + def apply(gguf: GGUFFile): LlamaTokenizer = new LlamaTokenizer(gguf) diff --git a/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/util/Logger.scala b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/util/Logger.scala new file mode 100644 index 00000000..2e90408d --- /dev/null +++ b/cyfra-llama/src/main/scala/io/computenode/cyfra/llama/util/Logger.scala @@ -0,0 +1,12 @@ +package io.computenode.cyfra.llama.util + +import org.slf4j.LoggerFactory + +/** Logger for the Llama module using SLF4J. */ +object Logger: + private val logger = LoggerFactory.getLogger("io.computenode.cyfra.llama") + + def info(msg: => String): Unit = if logger.isInfoEnabled then logger.info(msg) + def debug(msg: => String): Unit = if logger.isDebugEnabled then logger.debug(msg) + def warn(msg: => String): Unit = if logger.isWarnEnabled then logger.warn(msg) + def error(msg: => String): Unit = if logger.isErrorEnabled then logger.error(msg) diff --git a/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DequantizationTest.scala b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DequantizationTest.scala new file mode 100644 index 00000000..8bc07cd5 --- /dev/null +++ b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DequantizationTest.scala @@ -0,0 +1,73 @@ +package io.computenode.cyfra.llama + +import io.computenode.cyfra.llama.gguf.{Dequantize => CpuDequantize} +import munit.FunSuite + +import java.nio.{ByteBuffer, ByteOrder} + +/** Tests for CPU dequantization utilities. + * + * GPU quantized matmul tests were removed with F32 pipeline. + */ +class DequantizationTest extends FunSuite: + + /** Convert float32 to fp16 (approximate, for test purposes). 
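+ * Illustrative mapping (IEEE 754 binary16: 1 sign bit, 5 exponent bits with bias 15,
+ * 10 mantissa bits):
+ * {{{
+ * // floatToFp16(1.0f) == 0x3C00  (exponent 127 -> 15; mantissa truncated, not rounded)
+ * }}}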
*/ + def floatToFp16(f: Float): Short = + val bits = java.lang.Float.floatToIntBits(f) + val sign = (bits >> 31) & 1 + val exp = (bits >> 23) & 0xFF + val mant = bits & 0x7FFFFF + + if exp == 0 then + // Zero or denormalized + (sign << 15).toShort + else if exp == 255 then + // Infinity or NaN + ((sign << 15) | 0x7C00).toShort + else + val newExp = exp - 127 + 15 + if newExp <= 0 then + // Underflow to zero + (sign << 15).toShort + else if newExp >= 31 then + // Overflow to infinity + ((sign << 15) | 0x7C00).toShort + else + val newMant = mant >> 13 + ((sign << 15) | (newExp << 10) | newMant).toShort + + test("Q4_K scale extraction: verify is < 4 vs is >= 4 logic"): + // This test verifies the scale extraction matches llama.cpp's get_scale_min_k4 + // For is < 4: sc = scales[is] & 0x3F, m = scales[is+4] & 0x3F + // For is >= 4: sc = (scales[is+4] & 0x0F) | ((scales[is-4] >> 6) << 4) + // m = ((scales[is+4] >> 4) & 0x0F) | ((scales[is] >> 6) << 4) + + val scales = Array[Byte]( + 0x3f, 0x3e, 0x3d, 0x3c, // scales[0-3]: values 63, 62, 61, 60 + 0x10, 0x20, 0x30, 0x40.toByte, // scales[4-7]: mins for j<4 + 0x05, 0x06, 0x07, 0x08, // scales[8-11]: for j>=4 extraction + ) + + // Verify CPU extraction for j=0 + val (sc0, m0) = Dequantize.getScaleMinK4(0, scales) + assertEquals(sc0.toInt, 63, "scale for j=0 should be 63") + assertEquals(m0.toInt, 16, "min for j=0 should be 16") + + // Verify CPU extraction for j=4 + // sc4 = (scales[8] & 0x0F) | ((scales[0] >> 6) << 4) + // scales[8] = 0x05, scales[0] = 0x3F -> 0x3F >> 6 = 0 + // sc4 = (5 & 0x0F) | (0 << 4) = 5 + val (sc4, m4) = Dequantize.getScaleMinK4(4, scales) + assertEquals(sc4.toInt, 5, "scale for j=4 should be 5") + + // Helper method to expose getScaleMinK4 for testing + object Dequantize: + def getScaleMinK4(j: Int, scales: Array[Byte]): (Float, Float) = + if j < 4 then + val d = (scales(j) & 0x3F).toFloat + val m = (scales(j + 4) & 0x3F).toFloat + (d, m) + else + val d = ((scales(j + 4) & 0x0F) | ((scales(j - 4) >> 6) << 4)).toFloat + val m = ((scales(j + 4) >> 4) & 0x0F | ((scales(j) >> 6) << 4)).toFloat + (d, m) diff --git a/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DirectBenchmarkTest.scala b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DirectBenchmarkTest.scala new file mode 100644 index 00000000..7fdfcf45 --- /dev/null +++ b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/DirectBenchmarkTest.scala @@ -0,0 +1,91 @@ +package io.computenode.cyfra.llama + +import io.computenode.cyfra.llama.gguf.GGUFReader +import io.computenode.cyfra.llama.inference.LlamaInference +import io.computenode.cyfra.llama.tokenizer.LlamaTokenizer +import io.computenode.cyfra.llama.model.LlamaModel +import io.computenode.cyfra.llama.pipeline.LlamaF16Pipeline +import io.computenode.cyfra.runtime.VkCyfraRuntime +import munit.FunSuite + +import java.nio.file.{Files, Paths} +import scala.concurrent.duration.* + +/** Direct benchmark to verify which code path is actually running. 
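+ * Loads the F16 model directly, runs three warmup generations, then times greedy
+ * decoding and reports tok/s, followed by a five-run consistency check.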
*/ +class DirectBenchmarkTest extends FunSuite: + + val modelPath = "cyfra-llama/Llama-3.2-1B-Instruct-f16.gguf" + + override def munitTimeout: Duration = 10.minutes + + test("Direct F16 Pipeline.generate benchmark"): + assume(Files.exists(Paths.get(modelPath)), s"Model not found: $modelPath") + + VkCyfraRuntime.using: + println("Loading model...") + val model = LlamaModel.fromGGUF(Paths.get(modelPath)) + + try + println("Creating inference...") + val inference = new LlamaInference(model, maxT = 2048) + val f16Pipeline = inference.getF16Pipeline + + val tokenizer = LlamaTokenizer(model.gguf) + val promptText = "Once upon a time" + val promptTokens = tokenizer.encode(promptText) + + println("\n" + "=" * 60) + println(" Llama 3.2 1B F16 - KV Cache Benchmark (Cyfra GPU)") + println("=" * 60) + + // Warmup - 3 generations to ensure everything is compiled and cached + println("\nWarming up (3 generations)...") + for i <- 1 to 3 do + f16Pipeline.generate(promptTokens, 20, temperature = 0.0f) + println(s" warmup $i done") + + // Benchmark with longer generation + val maxTokens = 128 + println(s"\n--- Benchmark: $maxTokens tokens ---") + println(s"Prompt: '$promptText'\n") + + // Timed generation with output + print("Output: ") + val start = System.nanoTime() + val generated = f16Pipeline.generate( + promptTokens = promptTokens, + maxNewTokens = maxTokens, + temperature = 0.0f, // greedy + onToken = token => print(tokenizer.decodeToken(token)), + stopTokens = Set(tokenizer.eosToken, 128009), + ) + val elapsed = (System.nanoTime() - start) / 1e6 + println("\n") + + val tokPerSec = generated.length * 1000.0 / elapsed + println(s"Generated: ${generated.length} tokens") + println(s"Time: ${elapsed.toInt} ms") + println(f"Throughput: $tokPerSec%.1f tok/s") + + // Multiple runs for consistency + println(s"\n--- Consistency check (5 runs x $maxTokens tokens) ---") + val times = (1 to 5).map: i => + val runStart = System.nanoTime() + val tokens = f16Pipeline.generate(promptTokens, maxTokens, temperature = 0.0f) + val runElapsed = (System.nanoTime() - runStart) / 1e6 + val runTokPerSec = tokens.length * 1000.0 / runElapsed + println(f" Run $i: ${tokens.length} tokens in ${runElapsed.toInt}%5d ms = $runTokPerSec%.1f tok/s") + runElapsed + + val avgTime = times.sum / times.length + val avgTokPerSec = maxTokens * 1000.0 / avgTime + val minTime = times.min + val maxTokPerSec = maxTokens * 1000.0 / minTime + + println("\n" + "=" * 60) + println(f" Average: $avgTokPerSec%.1f tok/s") + println(f" Best: $maxTokPerSec%.1f tok/s") + println("=" * 60) + + finally + model.close() diff --git a/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/F16KVCacheTest.scala b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/F16KVCacheTest.scala new file mode 100644 index 00000000..c69b1d28 --- /dev/null +++ b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/F16KVCacheTest.scala @@ -0,0 +1,92 @@ +package io.computenode.cyfra.llama + +import io.computenode.cyfra.llama.inference.LlamaInference +import io.computenode.cyfra.llama.model.LlamaModel +import io.computenode.cyfra.llama.tokenizer.LlamaTokenizer +import io.computenode.cyfra.runtime.VkCyfraRuntime +import munit.FunSuite + +import java.nio.file.{Files, Paths} +import scala.concurrent.duration.* + +/** Tests for F16 KV Cache Pipeline with Vec4-optimized matmuls. + * + * This tests LlamaF16Pipeline which provides O(1) per-token inference + * by maintaining an F16 KV cache on GPU. Uses Vec4 weight loads for 4x bandwidth. 
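+ *
+ * Note that only the attention score/output reads scan the cached positions;
+ * the projection and FFN matmuls stay independent of sequence position.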
+ */ +class F16KVCacheTest extends FunSuite: + + val modelPath = Paths.get("cyfra-llama/Llama-3.2-1B-Instruct-f16.gguf") + + override def munitTimeout: Duration = 15.minutes + + test("F16 KV Cache Pipeline - longer generation benchmark"): + assume(Files.exists(modelPath), s"Model not found: $modelPath") + + VkCyfraRuntime.using: + println("\n" + "=" * 70) + println(" F16 KV Cache Pipeline - Performance Benchmark") + println("=" * 70) + + val model = LlamaModel.fromGGUF(modelPath) + val tokenizer = LlamaTokenizer(model.gguf) + + try + val inference = new LlamaInference(model, maxT = 2048) + val f16Pipeline = inference.getF16Pipeline + + val promptText = "Here is a Python server that creates a new user in the database and the repository:" + val promptTokens = tokenizer.encode(promptText) + val maxTokens = 1000 + + println(s"\nPrompt: '$promptText'") + println(s"Generating $maxTokens tokens...") + + // Warmup + println("Warming up (1 generation)...") + f16Pipeline.generate(promptTokens, 100, temperature = 0.2f) + + + // Benchmark + println(s"\n--- Benchmark: $maxTokens tokens ---\n") + print("Output: " + promptText) + val wallStart = System.nanoTime() + val generated = f16Pipeline.generate( + promptTokens = promptTokens, + maxNewTokens = maxTokens, + temperature = 0.2f, + topP = 0.9f, + onToken = token => print(tokenizer.decodeToken(token)), + stopTokens = Set(tokenizer.eosToken, 128009), // EOS + end-of-turn + reportStats = true, // Print GPU execution timing + ) + val wallElapsed = (System.nanoTime() - wallStart) / 1e6 + println("\n") + + val wallTokPerSec = generated.length * 1000.0 / wallElapsed + println(f"Wall-clock time: ${wallElapsed.toInt} ms ($wallTokPerSec%.2f tok/s)") + + // Multiple runs - GPU-only timing + println(s"\n--- Consistency check (5 runs x $maxTokens tokens) ---") + val gpuTimes = (1 to 5).map: i => + f16Pipeline.generate( + promptTokens = promptTokens, + maxNewTokens = maxTokens, + temperature = 0.2f, + stopTokens = Set(tokenizer.eosToken, 128009), + ) + val stats = f16Pipeline.lastStats.get + println(f" Run $i: ${stats.generatedTokens} tokens, decode ${stats.decodeTimeMs.toInt}%5d ms = ${stats.decodeTokPerSec}%.2f tok/s (prefill ${stats.prefillTimeMs.toInt} ms)") + stats.decodeTokPerSec + + val avgDecodeTokPerSec = gpuTimes.sum / gpuTimes.length + + println("\n" + "=" * 70) + println(f" Average decode throughput: $avgDecodeTokPerSec%.2f tok/s") + println(" Memory usage: ~50% of F32 (F16 weights + F16 KV cache)") + println("=" * 70) + + finally + model.close() + +end F16KVCacheTest diff --git a/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/F16TopPSampleTest.scala b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/F16TopPSampleTest.scala new file mode 100644 index 00000000..c3a86cc3 --- /dev/null +++ b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/F16TopPSampleTest.scala @@ -0,0 +1,368 @@ +package io.computenode.cyfra.llama + +import io.computenode.cyfra.core.GBufferRegion +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.llama.programs.f16.F16TopPSampleProgram +import io.computenode.cyfra.llama.programs.f16.F16TopPSampleProgram.{ProgramLayout, SampleParams, Sizes} +import io.computenode.cyfra.runtime.VkCyfraRuntime +import munit.FunSuite + +import java.nio.{ByteBuffer, ByteOrder} + +/** Tests for GPU-based top-p sampling. 
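+ *
+ * Top-p (nucleus) sampling as implemented by the CPU reference below: scale the
+ * logits by 1/temperature, softmax, sort descending, keep the smallest prefix
+ * whose cumulative probability reaches topP, then pick the first token whose
+ * running mass crosses randomValue * (mass of the kept prefix). Worked example:
+ * {{{
+ * // sorted probs 0.6, 0.3, 0.08, 0.02 with topP = 0.9 -> prefix {0.6, 0.3}
+ * // randomValue = 0.5 -> threshold 0.45 -> first token  (0.6 >= 0.45)
+ * // randomValue = 0.8 -> threshold 0.72 -> second token (0.6 + 0.3 >= 0.72)
+ * }}}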
+ * + * Verifies that the GPU sampler produces correct results for: + * - Low temperature (near-greedy) decoding + * - Top-p sampling with peaked distributions + * - Top-p sampling with different random values + */ +class F16TopPSampleTest extends FunSuite: + + val SMALL_VOCAB = 1024 + val LARGE_VOCAB = 32000 // Typical LLM size + + /** CPU reference implementation for top-p sampling */ + def cpuTopPSample(logits: Array[Float], temperature: Float, topP: Float, randomValue: Float): Int = + val scaled = logits.map(_ / math.max(temperature, 0.0001f)) + val maxLogit = scaled.max + val expLogits = scaled.map(x => math.exp(x - maxLogit).toFloat) + val sumExp = expLogits.sum + val probs = expLogits.map(_ / sumExp) + + val indexed = probs.zipWithIndex.sortBy(-_._1) + + var cumSum = 0.0f + var cutoffIdx = 0 + while cutoffIdx < indexed.length && cumSum < topP do + cumSum += indexed(cutoffIdx)._1 + cutoffIdx += 1 + + val topTokens = indexed.take(cutoffIdx) + val topSum = topTokens.map(_._1).sum + val threshold = randomValue * topSum + + var acc = 0.0f + var result = topTokens.last._2 + for (prob, idx) <- topTokens do + acc += prob + if acc >= threshold && result == topTokens.last._2 then result = idx + + result + + def allocateBuffer(size: Int): ByteBuffer = + ByteBuffer.allocateDirect(size * 4).order(ByteOrder.nativeOrder()) + + def copyToBuffer(arr: Array[Float], buf: ByteBuffer): Unit = + buf.clear() + buf.asFloatBuffer().put(arr) + buf.rewind() + + test("Low temperature decoding returns argmax"): + VkCyfraRuntime.using: + val vocabSize = SMALL_VOCAB + val sizes = Sizes(vocabSize) + + val logits = Array.fill(vocabSize)(-10.0f) + val expectedMaxIdx = 42 + logits(expectedMaxIdx) = 5.0f + logits(100) = 2.0f + logits(200) = 1.0f + + val logitsBuf = allocateBuffer(vocabSize) + copyToBuffer(logits, logitsBuf) + val resultBuf = allocateBuffer(1) + + // Very low temperature + low random value = greedy + val paramsBuf = ByteBuffer.allocateDirect(16).order(ByteOrder.nativeOrder()) + paramsBuf.putFloat(0.001f) // very low temperature + paramsBuf.putFloat(0.99f) // high topP + paramsBuf.putFloat(0.01f) // low random value + paramsBuf.putFloat(0.0f) // padding + paramsBuf.rewind() + + val program = F16TopPSampleProgram.forward(sizes) + val result = new Array[Int](1) + + val region = GBufferRegion + .allocate[ProgramLayout] + .map(layout => program.execute(sizes, layout)) + + region.runUnsafe( + init = ProgramLayout( + logits = GBuffer[Float32](logitsBuf), + params = GUniform[SampleParams](paramsBuf), + result = GBuffer[Int32](resultBuf), + ), + onDone = layout => + layout.result.read(resultBuf) + resultBuf.rewind() + result(0) = resultBuf.asIntBuffer().get(0), + ) + + assertEquals(result(0), expectedMaxIdx, s"Low-temp sampling should return argmax $expectedMaxIdx") + + test("Top-p sampling with strongly peaked distribution"): + VkCyfraRuntime.using: + val vocabSize = SMALL_VOCAB + val sizes = Sizes(vocabSize) + + val logits = Array.fill(vocabSize)(-10.0f) + val dominantIdx = 123 + logits(dominantIdx) = 10.0f + logits(200) = 0.0f + logits(300) = -1.0f + + val logitsBuf = allocateBuffer(vocabSize) + copyToBuffer(logits, logitsBuf) + val resultBuf = allocateBuffer(1) + + val temperature = 1.0f + val topP = 0.9f + val randomValue = 0.5f + + val paramsBuf = ByteBuffer.allocateDirect(16).order(ByteOrder.nativeOrder()) + paramsBuf.putFloat(temperature) + paramsBuf.putFloat(topP) + paramsBuf.putFloat(randomValue) + paramsBuf.putFloat(0.0f) + paramsBuf.rewind() + + val program = F16TopPSampleProgram.forward(sizes) + 
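+      // Cross-check (added illustration): the CPU reference defined above must also
+      // pick the dominant token here, since p(dominantIdx) is ~1.0 after softmax and
+      // the top-p prefix therefore contains only that token.
+      assert(cpuTopPSample(logits, temperature, topP, randomValue) == dominantIdx)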
+      val result = new Array[Int](1)
+
+      val region = GBufferRegion
+        .allocate[ProgramLayout]
+        .map(layout => program.execute(sizes, layout))
+
+      region.runUnsafe(
+        init = ProgramLayout(
+          logits = GBuffer[Float32](logitsBuf),
+          params = GUniform[SampleParams](paramsBuf),
+          result = GBuffer[Int32](resultBuf),
+        ),
+        onDone = layout =>
+          layout.result.read(resultBuf)
+          resultBuf.rewind()
+          result(0) = resultBuf.asIntBuffer().get(0),
+      )
+
+      assertEquals(result(0), dominantIdx, "With peaked distribution should select dominant token")
+
+  test("Top-p sampling respects random value"):
+    VkCyfraRuntime.using:
+      val vocabSize = SMALL_VOCAB
+      val sizes = Sizes(vocabSize)
+
+      val logits = Array.fill(vocabSize)(-10.0f)
+      logits(10) = 2.0f // Highest
+      logits(20) = 1.9f
+      logits(30) = 1.8f
+
+      val logitsBuf = allocateBuffer(vocabSize)
+      copyToBuffer(logits, logitsBuf)
+      val resultBuf = allocateBuffer(1)
+
+      val temperature = 1.0f
+      val topP = 0.99f
+      val lowRandomValue = 0.1f
+
+      val paramsBuf = ByteBuffer.allocateDirect(16).order(ByteOrder.nativeOrder())
+      paramsBuf.putFloat(temperature)
+      paramsBuf.putFloat(topP)
+      paramsBuf.putFloat(lowRandomValue)
+      paramsBuf.putFloat(0.0f)
+      paramsBuf.rewind()
+
+      val program = F16TopPSampleProgram.forward(sizes)
+      val result = new Array[Int](1)
+
+      val region = GBufferRegion
+        .allocate[ProgramLayout]
+        .map(layout => program.execute(sizes, layout))
+
+      region.runUnsafe(
+        init = ProgramLayout(
+          logits = GBuffer[Float32](logitsBuf),
+          params = GUniform[SampleParams](paramsBuf),
+          result = GBuffer[Int32](resultBuf),
+        ),
+        onDone = layout =>
+          layout.result.read(resultBuf)
+          resultBuf.rewind()
+          result(0) = resultBuf.asIntBuffer().get(0),
+      )
+
+      assertEquals(result(0), 10, "Low random value should select highest probability token")
+
+  test("GPU sampling matches CPU reference for peaked distribution"):
+    VkCyfraRuntime.using:
+      val vocabSize = SMALL_VOCAB
+      val sizes = Sizes(vocabSize)
+
+      val random = new scala.util.Random(42)
+      val logits = Array.fill(vocabSize)(random.nextFloat() * 2 - 1)
+      logits(50) = 8.0f // Dominant
+      logits(100) = 5.0f
+      logits(150) = 4.0f
+
+      val temperature = 0.8f
+      val topP = 0.9f
+      val randomValue = 0.3f
+
+      val cpuResult = cpuTopPSample(logits, temperature, topP, randomValue)
+
+      val logitsBuf = allocateBuffer(vocabSize)
+      copyToBuffer(logits, logitsBuf)
+      val resultBuf = allocateBuffer(1)
+
+      val paramsBuf = ByteBuffer.allocateDirect(16).order(ByteOrder.nativeOrder())
+      paramsBuf.putFloat(temperature)
+      paramsBuf.putFloat(topP)
+      paramsBuf.putFloat(randomValue)
+      paramsBuf.putFloat(0.0f)
+      paramsBuf.rewind()
+
+      val program = F16TopPSampleProgram.forward(sizes)
+      val gpuResult = new Array[Int](1)
+
+      val region = GBufferRegion
+        .allocate[ProgramLayout]
+        .map(layout => program.execute(sizes, layout))
+
+      region.runUnsafe(
+        init = ProgramLayout(
+          logits = GBuffer[Float32](logitsBuf),
+          params = GUniform[SampleParams](paramsBuf),
+          result = GBuffer[Int32](resultBuf),
+        ),
+        onDone = layout =>
+          layout.result.read(resultBuf)
+          resultBuf.rewind()
+          gpuResult(0) = resultBuf.asIntBuffer().get(0),
+      )
+
+      println(s"CPU result: $cpuResult, GPU result: ${gpuResult(0)}")
+      // The distribution is peaked enough that the CPU reference deterministically
+      // picks the dominant token, so the GPU result must agree with it
+      assertEquals(cpuResult, 50, "CPU reference should pick the dominant token")
+      assertEquals(gpuResult(0), cpuResult, "GPU sampling should match the CPU reference")
+
+  test("Large vocabulary sampling"):
+    VkCyfraRuntime.using:
+      val vocabSize = LARGE_VOCAB
+      val sizes = Sizes(vocabSize)
+
+      val logits = Array.fill(vocabSize)(0.0f)
+      val expectedMaxIdx = 28756
+      logits(expectedMaxIdx) = 10.0f
+
+      val logitsBuf = allocateBuffer(vocabSize)
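+      // Sanity check (added illustration): the CPU argmax is expectedMaxIdx, which
+      // the near-greedy GPU sampling below is expected to reproduce.
+      assert(logits.indices.maxBy(logits(_)) == expectedMaxIdx)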
copyToBuffer(logits, logitsBuf) + val resultBuf = allocateBuffer(1) + + // Low temp + low random = should get argmax + val paramsBuf = ByteBuffer.allocateDirect(16).order(ByteOrder.nativeOrder()) + paramsBuf.putFloat(0.001f) // very low temperature + paramsBuf.putFloat(0.99f) + paramsBuf.putFloat(0.01f) // low random value + paramsBuf.putFloat(0.0f) + paramsBuf.rewind() + + val program = F16TopPSampleProgram.forward(sizes) + val result = new Array[Int](1) + + val region = GBufferRegion + .allocate[ProgramLayout] + .map(layout => program.execute(sizes, layout)) + + region.runUnsafe( + init = ProgramLayout( + logits = GBuffer[Float32](logitsBuf), + params = GUniform[SampleParams](paramsBuf), + result = GBuffer[Int32](resultBuf), + ), + onDone = layout => + layout.result.read(resultBuf) + resultBuf.rewind() + result(0) = resultBuf.asIntBuffer().get(0), + ) + + assertEquals(result(0), expectedMaxIdx, s"Large vocab should find token at $expectedMaxIdx") + + test("Benchmark: GPU sampling vs CPU"): + VkCyfraRuntime.using: + val vocabSize = LARGE_VOCAB + val sizes = Sizes(vocabSize) + val numIterations = 100 + + val random = new scala.util.Random(123) + val logits = Array.fill(vocabSize)(random.nextFloat() * 4 - 2) + logits(random.nextInt(vocabSize)) = 10.0f + + val temperature = 0.8f + val topP = 0.9f + + // Benchmark CPU + val cpuStart = System.nanoTime() + for _ <- 1 to numIterations do cpuTopPSample(logits, temperature, topP, random.nextFloat()) + val cpuTimeMs = (System.nanoTime() - cpuStart) / 1e6 + + // Setup GPU + val logitsBuf = allocateBuffer(vocabSize) + copyToBuffer(logits, logitsBuf) + val resultBuf = allocateBuffer(1) + val paramsBuf = ByteBuffer.allocateDirect(16).order(ByteOrder.nativeOrder()) + + val program = F16TopPSampleProgram.forward(sizes) + + // Warmup GPU + for _ <- 1 to 10 do + paramsBuf.clear() + paramsBuf.putFloat(temperature) + paramsBuf.putFloat(topP) + paramsBuf.putFloat(random.nextFloat()) + paramsBuf.putFloat(0.0f) + paramsBuf.rewind() + + val region = GBufferRegion + .allocate[ProgramLayout] + .map(layout => program.execute(sizes, layout)) + + region.runUnsafe( + init = ProgramLayout( + logits = GBuffer[Float32](logitsBuf), + params = GUniform[SampleParams](paramsBuf), + result = GBuffer[Int32](resultBuf), + ), + onDone = _ => (), + ) + + // Benchmark GPU + val gpuStart = System.nanoTime() + for _ <- 1 to numIterations do + paramsBuf.clear() + paramsBuf.putFloat(temperature) + paramsBuf.putFloat(topP) + paramsBuf.putFloat(random.nextFloat()) + paramsBuf.putFloat(0.0f) + paramsBuf.rewind() + + val region = GBufferRegion + .allocate[ProgramLayout] + .map(layout => program.execute(sizes, layout)) + + region.runUnsafe( + init = ProgramLayout( + logits = GBuffer[Float32](logitsBuf), + params = GUniform[SampleParams](paramsBuf), + result = GBuffer[Int32](resultBuf), + ), + onDone = _ => (), + ) + val gpuTimeMs = (System.nanoTime() - gpuStart) / 1e6 + + println(s"\n--- Sampling Benchmark ($numIterations iterations, vocab=$vocabSize) ---") + println(f"CPU top-p: ${cpuTimeMs}%.2f ms total (${cpuTimeMs / numIterations}%.3f ms/sample)") + println(f"GPU top-p: ${gpuTimeMs}%.2f ms total (${gpuTimeMs / numIterations}%.3f ms/sample)") + println(f"Speedup: ${cpuTimeMs / gpuTimeMs}%.2fx") + + assert(gpuTimeMs < cpuTimeMs * 2, "GPU should not be much slower than CPU") + +end F16TopPSampleTest diff --git a/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/GGUFTest.scala b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/GGUFTest.scala new file mode 100644 index 
00000000..456597a0 --- /dev/null +++ b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/GGUFTest.scala @@ -0,0 +1,90 @@ +package io.computenode.cyfra.llama + +import io.computenode.cyfra.llama.model.LlamaModel +import io.computenode.cyfra.llama.gguf.GGUFReader +import munit.FunSuite + +import java.nio.file.{Files, Path, Paths} + +class GGUFTest extends FunSuite: + // Set this to the path of a GGUF model file for testing + val testModelPath: Path = Paths.get( + sys.env.getOrElse("LLAMA_MODEL_PATH", "models/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf") + ) + + test("parse GGUF header and metadata"): + assume(Files.exists(testModelPath), s"Model file not found at $testModelPath") + + val gguf = GGUFReader.read(testModelPath) + try + println(s"GGUF Version: ${gguf.version}") + println(s"Tensors: ${gguf.tensors.size}") + println(s"Data offset: ${gguf.dataOffset}") + println() + + println("Metadata keys:") + gguf.metadata.keys.toSeq.sorted.foreach(k => println(s" $k")) + println() + + println("Architecture:") + gguf.getString("general.architecture").foreach(v => println(s" $v")) + + println("\nModel parameters:") + val arch = gguf.getString("general.architecture").getOrElse("llama") + gguf.getInt(s"$arch.embedding_length").foreach(v => println(s" embedding_length: $v")) + gguf.getInt(s"$arch.feed_forward_length").foreach(v => println(s" feed_forward_length: $v")) + gguf.getInt(s"$arch.attention.head_count").foreach(v => println(s" head_count: $v")) + gguf.getInt(s"$arch.attention.head_count_kv").foreach(v => println(s" head_count_kv: $v")) + gguf.getInt(s"$arch.block_count").foreach(v => println(s" block_count: $v")) + gguf.getInt(s"$arch.vocab_size").foreach(v => println(s" vocab_size: $v")) + gguf.getInt(s"$arch.context_length").foreach(v => println(s" context_length: $v")) + println() + + println("Tensors (first 30):") + gguf.tensors.take(30).foreach: t => + println(s" ${t.name}: shape=${t.shape.mkString("x")}, type=${t.quantType}, offset=${t.offset}") + + assert(gguf.tensors.nonEmpty) + finally + gguf.close() + + test("load LlamaModel from GGUF"): + assume(Files.exists(testModelPath), s"Model file not found at $testModelPath") + + val model = LlamaModel.fromGGUF(testModelPath) + try + model.logInfo() + + // Verify config was extracted correctly + println(s"\nExtracted config:") + println(s" hiddenSize: ${model.config.hiddenSize}") + println(s" intermediateSize: ${model.config.intermediateSize}") + println(s" numAttentionHeads: ${model.config.numAttentionHeads}") + println(s" numKeyValueHeads: ${model.config.numKeyValueHeads}") + println(s" numHiddenLayers: ${model.config.numHiddenLayers}") + println(s" vocabSize: ${model.config.vocabSize}") + println(s" maxPositionEmbeddings: ${model.config.maxPositionEmbeddings}") + println(s" headSize: ${model.config.headSize}") + println(s" gqaRatio: ${model.config.gqaRatio}") + println(s" ropeTheta: ${model.config.ropeTheta}") + + // Verify expected tensor names exist + val expectedTensors = Seq( + LlamaModel.TensorNames.tokenEmbed, + LlamaModel.TensorNames.outputNorm, + LlamaModel.TensorNames.attnNorm(0), + LlamaModel.TensorNames.attnQ(0), + LlamaModel.TensorNames.ffnNorm(0), + LlamaModel.TensorNames.ffnGate(0), + ) + + println(s"\nChecking expected tensors:") + expectedTensors.foreach: name => + model.getTensor(name) match + case Some(t) => println(s" ✓ $name: ${t.shape.mkString("x")} (${t.quantType})") + case None => println(s" ✗ $name: NOT FOUND") + + assert(model.config.hiddenSize > 0) + assert(model.config.numHiddenLayers > 0) + finally + 
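+      // always release the model, even if an assertion above failed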
model.close() diff --git a/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/ShaderDumpTest.scala b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/ShaderDumpTest.scala new file mode 100644 index 00000000..f7fcc9f6 --- /dev/null +++ b/cyfra-llama/src/test/scala/io/computenode/cyfra/llama/ShaderDumpTest.scala @@ -0,0 +1,158 @@ +package io.computenode.cyfra.llama + +import io.computenode.cyfra.dsl.{*, given} +import io.computenode.cyfra.core.{GProgram, GioProgram} +import io.computenode.cyfra.core.layout.Layout +import io.computenode.cyfra.spirv.compilers.DSLCompiler +import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvDisassembler, SpirvTool, SpirvToolsRunner, SpirvValidator} +import io.computenode.cyfra.llama.programs.f16.* +import munit.FunSuite + +import java.nio.ByteBuffer +import java.nio.file.{Files, Path, Paths} + +/** Dumps GPU program shaders to SPIR-V assembly and GLSL for inspection. + * + * This is useful for: + * - Comparing generated code with llama.cpp shaders + * - Verifying optimizations (Vec4 loads, unrolling, subgroup ops) + * - Debugging shader compilation issues + */ +class ShaderDumpTest extends FunSuite: + + val outputDir = Paths.get("cyfra-llama/output/shaders") + + override def beforeAll(): Unit = + Files.createDirectories(outputDir) + + // ============ F16 Programs ============ + + test("Dump F16 Embedding shader"): + val program = F16EmbeddingProgram.forward(F16EmbeddingProgram.Sizes( + seqLen = 2, + hiddenSize = 2048, + vocabSize = 32000, + )) + dumpProgram("f16_embedding", program) + + test("Dump F16 RMSNorm shader"): + val program = F16RMSNormProgram.forward(F16RMSNormProgram.Sizes( + numRows = 1, + rowSize = 2048, + eps = 1e-6f, + )) + dumpProgram("f16_rmsnorm", program) + + test("Dump F16 Fused RoPE shader"): + val program = F16FusedRoPEProgram.forward(F16FusedRoPEProgram.Sizes( + B = 1, + T = 2, + numHeadsQ = 32, + numHeadsK = 8, + headSize = 64, + theta = 10000f, + )) + dumpProgram("f16_fused_rope", program) + + test("Dump F16 MatmulVec Hybrid shader (Vec4 weights, scalar input)"): + val program = F16MatmulVecHybridProgram.forward(F16MatmulVecHybridProgram.Sizes( + batchSize = 1, + inFeatures = 2048, + outFeatures = 2048, + )) + dumpProgram("f16_matmul_vec_hybrid", program) + + test("Dump F16 MatmulVec Hybrid shader (Vec4 weights, Vec4 input)"): + val program = F16MatmulVecHybridProgram.forwardVec4(F16MatmulVecHybridProgram.Sizes( + batchSize = 1, + inFeatures = 2048, + outFeatures = 2048, + )) + dumpProgram("f16_matmul_vec4_hybrid", program) + + test("Dump F16 SwiGLU shader"): + val program = F16SwiGLUProgram.forward(F16SwiGLUProgram.Sizes(5632)) + dumpProgram("f16_swiglu", program) + + test("Dump F16 ResidualAdd shader"): + val program = F16ResidualAddProgram.forward(F16ResidualAddProgram.Sizes(2048)) + dumpProgram("f16_residual_add", program) + + test("Dump F16 Output Vec4 shader"): + val program = F16OutputVec4Program.forward(F16OutputVec4Program.Sizes( + batchSize = 1, + hiddenSize = 2048, + vocabSize = 32000, + )) + dumpProgram("f16_output_vec4", program) + + // ============ Split Attention Programs ============ + + test("Dump F16 Attention Scores shader"): + val program = F16AttentionScoresProgram.forward(F16AttentionScoresProgram.Sizes( + B = 1, + T = 1, + NH = 32, + NKV = 8, + headSize = 64, + maxSeqLen = 2048, + kCacheLayerOffset = 0, + L = 16, + )) + dumpProgram("f16_attention_scores", program) + + test("Dump F16 Attention Softmax shader"): + val program = F16AttentionSoftmaxProgram.forward(F16AttentionSoftmaxProgram.Sizes( + B = 
1, + T = 1, + NH = 32, + maxSeqLen = 2048, + )) + dumpProgram("f16_attention_softmax", program) + + test("Dump F16 Attention Output shader"): + val program = F16AttentionOutputProgram.forward(F16AttentionOutputProgram.Sizes( + B = 1, + T = 1, + NH = 32, + NKV = 8, + headSize = 64, + maxSeqLen = 2048, + vCacheLayerOffset = 0, + L = 16, + )) + dumpProgram("f16_attention_output", program) + + test("Dump F16 TopP Sample shader"): + val program = F16TopPSampleProgram.forward(F16TopPSampleProgram.Sizes( + vocabSize = 32000, + )) + dumpProgram("f16_top_p_sample", program) + + private def dumpProgram[P, L: Layout](name: String, program: GProgram[P, L]): Unit = + program match + case gioProgram: GioProgram[P, L] => + val layout = summon[Layout[L]] + val bindings = layout.toBindings(layout.layoutRef).toList + val shaderCode = DSLCompiler.compile(gioProgram.body(layout.layoutRef), bindings, gioProgram.workgroupSize) + + // Create runner with file outputs + val runner = SpirvToolsRunner( + validator = SpirvValidator.Enable(throwOnFail = false), + disassembler = SpirvDisassembler.Enable( + throwOnFail = false, + toolOutput = SpirvTool.ToFile(outputDir.resolve(s"$name.spvasm"), hashSuffix = false), + settings = Seq(), + ), + crossCompilation = SpirvCross.Enable( + throwOnFail = false, + toolOutput = SpirvTool.ToFile(outputDir.resolve(s"$name.glsl"), hashSuffix = false), + settings = Seq(SpirvTool.Param("--vulkan-semantics")), + ), + originalSpirvOutput = SpirvTool.ToFile(outputDir.resolve(s"$name.spv"), hashSuffix = false), + ) + + runner.processShaderCodeWithSpirvTools(shaderCode) + println(s"Dumped $name shaders to $outputDir") + case _ => + println(s"Cannot dump $name - not a GioProgram") diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala index 9e86560b..9f3e849b 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala @@ -11,9 +11,12 @@ import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} import io.computenode.cyfra.dsl.struct.{GStruct, GStructSchema} import io.computenode.cyfra.runtime.ExecutionHandler.{ BindingLogicError, + BufferCopyCall, + BufferCopyStep, Dispatch, DispatchType, ExecutionBinding, + ExecutionCall, ExecutionStep, PipelineBarrier, ShaderCall, @@ -22,15 +25,16 @@ import io.computenode.cyfra.runtime.ExecutionHandler.DispatchType.* import io.computenode.cyfra.runtime.ExecutionHandler.ExecutionBinding.{BufferBinding, UniformBinding} import io.computenode.cyfra.utility.Utility.timed import io.computenode.cyfra.vulkan.{VulkanContext, VulkanThreadContext} -import io.computenode.cyfra.vulkan.command.{CommandPool, Fence} +import io.computenode.cyfra.vulkan.command.{CommandPool, Fence, Semaphore} import io.computenode.cyfra.vulkan.compute.ComputePipeline import io.computenode.cyfra.vulkan.core.Queue import io.computenode.cyfra.vulkan.memory.{DescriptorPool, DescriptorPoolManager, DescriptorSet, DescriptorSetManager} import io.computenode.cyfra.vulkan.util.Util.{check, pushStack} import izumi.reflect.Tag import org.lwjgl.vulkan.VK10.* -import org.lwjgl.vulkan.VK13.{VK_ACCESS_2_SHADER_READ_BIT, VK_ACCESS_2_SHADER_WRITE_BIT, VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT, vkCmdPipelineBarrier2} -import org.lwjgl.vulkan.{VK13, VkCommandBuffer, VkCommandBufferBeginInfo, VkDependencyInfo, VkMemoryBarrier2, VkSubmitInfo} +import 
org.lwjgl.vulkan.VK13.{VK_ACCESS_2_SHADER_READ_BIT, VK_ACCESS_2_SHADER_WRITE_BIT, VK_ACCESS_2_TRANSFER_READ_BIT, VK_ACCESS_2_TRANSFER_WRITE_BIT, VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_2_TRANSFER_BIT, vkCmdPipelineBarrier2} +import org.lwjgl.vulkan.EXTDebugUtils.{vkCmdBeginDebugUtilsLabelEXT, vkCmdEndDebugUtilsLabelEXT} +import org.lwjgl.vulkan.{VkBufferCopy, VkCommandBuffer, VkCommandBufferBeginInfo, VkDebugUtilsLabelEXT, VkDependencyInfo, VkMemoryBarrier2, VkSubmitInfo, VkTimelineSemaphoreSubmitInfo} import scala.collection.mutable @@ -39,47 +43,138 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte private val dsManager: DescriptorSetManager = threadContext.descriptorSetManager private val commandPool: CommandPool.Reset = threadContext.commandPool + private val queue = commandPool.queue + + // Timeline semaphore for GPU-GPU synchronization (no CPU blocking between submissions) + private val timelineSemaphore = new Semaphore() + private var semaphoreValue: Long = 0 + + // Full execution cache - caches command buffer, descriptor sets, and result + // Keyed by (execution identity, layout bindings identity hash) + private case class CachedExecution( + resultBindings: Seq[GBinding[?]], + commandBuffer: VkCommandBuffer, + executeSteps: Seq[ExecutionStep], + var lastSemaphoreValue: Long, // Track which semaphore value this execution signals + var reuseCount: Int = 0, // Track reuse count for first-use sync + ) + private val executionCache = mutable.Map[(Int, Int), CachedExecution]() def handle[Params, EL: Layout, RL: Layout](execution: GExecution[Params, EL, RL], params: Params, layout: EL)(using VkAllocation): RL = - val (result, shaderCalls) = interpret(execution, params, layout) - - val descriptorSets = shaderCalls.map: - case ShaderCall(pipeline, layout, _) => - pipeline.pipelineLayout.sets - .map(dsManager.allocate) - .zip(layout) - .map: - case (set, bindings) => - set.update(bindings.map(x => VkAllocation.getUnderlying(x.binding).buffer)) - set - - val dispatches: Seq[Dispatch] = shaderCalls - .zip(descriptorSets) - .map: - case (ShaderCall(pipeline, layout, dispatch), sets) => - Dispatch(pipeline, layout, sets, dispatch) - - val (executeSteps, _) = dispatches.foldLeft((Seq.empty[ExecutionStep], Set.empty[GBinding[?]])): - case ((steps, dirty), step) => - val bindings = step.layout.flatten.map(_.binding) - if bindings.exists(dirty.contains) then (steps.appendedAll(Seq(PipelineBarrier, step)), bindings.toSet) - else (steps.appended(step), dirty ++ bindings) - - val commandBuffer = recordCommandBuffer(executeSteps) - val cleanup = () => - descriptorSets.flatten.foreach(dsManager.free) - commandPool.freeCommandBuffer(commandBuffer) - - val externalBindings = getAllBindings(executeSteps).map(VkAllocation.getUnderlying) - val deps = externalBindings.flatMap(_.execution.fold(Seq(_), _.toSeq)) - val pe = new PendingExecution(commandBuffer, deps, cleanup) - summon[VkAllocation].addExecution(pe) - externalBindings.foreach(_.execution = Left(pe)) // TODO we assume all accesses are read-write - result - - private def interpret[Params, EL: Layout, RL: Layout](execution: GExecution[Params, EL, RL], params: Params, layout: EL)(using + val layoutBindings = Layout[EL].toBindings(layout) + val layoutHash = layoutBindings.map(System.identityHashCode).hashCode() + val cacheKey = (System.identityHashCode(execution), layoutHash) + + executionCache.get(cacheKey) match + case Some(cached) => + // On first reuse, sync any pending writes (e.g., uniform buffer 
updates) + // This handles the case where initial uniform values differ from runtime values + if cached.reuseCount == 0 then + summon[VkAllocation].submitLayout(layout) + cached.reuseCount += 1 + + // Cache hit - submit with timeline semaphore (GPU-GPU sync, no CPU wait) + val waitValue = cached.lastSemaphoreValue + semaphoreValue += 1 + val signalValue = semaphoreValue + + submitWithSemaphore(cached.commandBuffer, waitValue, signalValue) + cached.lastSemaphoreValue = signalValue + + Layout[RL].fromBindings(cached.resultBindings) + + case None => + // Cache miss - full execution path + val (result, executionCalls) = interpret(execution, params, layout) + + // Convert ExecutionCalls to ExecutionSteps + val initialSteps: Seq[ExecutionStep] = executionCalls.map: + case ShaderCall(pipeline, layout, dispatch) => + val sets = pipeline.pipelineLayout.sets + .map(dsManager.allocate) + .zip(layout) + .map: + case (set, bindings) => + set.update(bindings.map(x => VkAllocation.getUnderlying(x.binding).buffer)) + set + Dispatch(pipeline, layout, sets, dispatch) + case BufferCopyCall(src, dst, sizeBytes) => + BufferCopyStep(src, dst, sizeBytes) + + val (executeSteps, _) = initialSteps.zipWithIndex.foldLeft((Seq.empty[ExecutionStep], Set.empty[GBinding[?]])): + case ((steps, dirty), (step, idx)) => + // Extract bindings by operation type + val (allBindings, writtenBindings) = step match + case Dispatch(_, layout, _, _) => + val allBindingsWithOp = layout.flatten + (allBindingsWithOp.map(_.binding), allBindingsWithOp.filter(b => b.operation == Operation.Write || b.operation == Operation.ReadWrite).map(_.binding)) + case BufferCopyStep(src, dst, _) => + (Seq(src, dst), Seq(dst)) // dst is written + case PipelineBarrier => + (Seq.empty, Seq.empty) + + // Need barrier if this step accesses any buffer that was written by a previous step + // This handles Read-after-Write (RAW) and Write-after-Write (WAW) hazards + val needsBarrier = allBindings.exists(dirty.contains) + + if needsBarrier then + // Reset dirty set to just this step's writes (barrier synchronizes everything) + (steps.appendedAll(Seq(PipelineBarrier, step)), writtenBindings.toSet) + else + // Add this step's writes to dirty set + (steps.appended(step), dirty ++ writtenBindings) + + val commandBuffer = recordCommandBuffer(executeSteps) + + // Submit with timeline semaphore + val waitValue = semaphoreValue + semaphoreValue += 1 + val signalValue = semaphoreValue + submitWithSemaphore(commandBuffer, waitValue, signalValue) + + // Cache the execution + val resultBindings = Layout[RL].toBindings(result) + executionCache(cacheKey) = CachedExecution(resultBindings, commandBuffer, executeSteps, signalValue) + + result + + /** Submit command buffer with timeline semaphore wait/signal for GPU-GPU pipelining */ + private def submitWithSemaphore(commandBuffer: VkCommandBuffer, waitValue: Long, signalValue: Long): Unit = pushStack: stack => + val timelineInfo = VkTimelineSemaphoreSubmitInfo + .calloc(stack) + .sType$Default() + .waitSemaphoreValueCount(1) + .pWaitSemaphoreValues(stack.longs(waitValue)) + .signalSemaphoreValueCount(1) + .pSignalSemaphoreValues(stack.longs(signalValue)) + + val submitInfo = VkSubmitInfo + .calloc(1, stack) + .sType$Default() + .pNext(timelineInfo) + .waitSemaphoreCount(1) + .pWaitSemaphores(stack.longs(timelineSemaphore.get)) + .pWaitDstStageMask(stack.ints(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT)) + .pCommandBuffers(stack.pointers(commandBuffer)) + .pSignalSemaphores(stack.longs(timelineSemaphore.get)) + + 
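+    // Descriptive note: wait/signal values increase monotonically, so successive
+    // submissions form an ordered GPU-side chain; the CPU never blocks here and
+    // only waits in sync() when results are read back.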
check(vkQueueSubmit(queue.get, submitInfo, VK_NULL_HANDLE), "Failed to submit command buffer") + + /** Wait for all GPU work to complete (call before reading results) */ + def sync(): Unit = + if semaphoreValue > 0 then + timelineSemaphore.waitValue(semaphoreValue) + + private def interpret[Params, EL: Layout, RL: Layout]( + execution: GExecution[Params, EL, RL], + params: Params, + layout: EL + )(using VkAllocation): (RL, Seq[ExecutionCall]) = + interpretUncached(execution, params, layout) + + private def interpretUncached[Params, EL: Layout, RL: Layout](execution: GExecution[Params, EL, RL], params: Params, layout: EL)(using VkAllocation, - ): (RL, Seq[ShaderCall]) = + ): (RL, Seq[ExecutionCall]) = val bindingsAcc: mutable.Map[GBinding[?], mutable.Buffer[GBinding[?]]] = mutable.Map.empty def mockBindings[L: Layout](layout: L): L = @@ -95,7 +190,7 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte mapper.fromBindings(res) // noinspection TypeParameterShadow - def interpretImpl[Params, EL: Layout, RL: Layout](execution: GExecution[Params, EL, RL], params: Params, layout: EL): (RL, Seq[ShaderCall]) = + def interpretImpl[Params, EL: Layout, RL: Layout](execution: GExecution[Params, EL, RL], params: Params, layout: EL): (RL, Seq[ExecutionCall]) = execution match case GExecution.Pure() => (layout, Seq.empty) case GExecution.Map(innerExec, map, cmap, cmapP) => @@ -129,6 +224,10 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte case GProgram.StaticDispatch(size) => DispatchType.Direct(size._1, size._2, size._3) // noinspection ScalaRedundantCast (layout.asInstanceOf[RL], Seq(ShaderCall(shader.underlying, shader.shaderBindings(layout), dispatch))) + case bufferCopy: GExecution.BufferCopy[EL] => + val (src, dst) = bufferCopy.getBuffers(layout) + // noinspection ScalaRedundantCast + (layout.asInstanceOf[RL], Seq(BufferCopyCall(src, dst, bufferCopy.sizeBytes))) case _ => ??? 
val (rl, steps) = interpretImpl(execution, params, mockBindings(layout)) @@ -143,6 +242,8 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte case x: Direct => x case Indirect(buffer, offset) => Indirect(bingingToVk(buffer), offset) ShaderCall(pipeline, nextLayout, nextDispatch) + case BufferCopyCall(src, dst, sizeBytes) => + BufferCopyCall(bingingToVk(src), bingingToVk(dst), sizeBytes) val mapper = Layout[RL] val res = mapper.fromBindings(mapper.toBindings(rl).map(bingingToVk.apply)) @@ -194,11 +295,17 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte val memoryBarrier = VkMemoryBarrier2 // TODO don't synchronise everything .calloc(1, stack) .sType$Default() - .srcStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT) - .srcAccessMask(VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT) - .dstStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT) - .dstAccessMask(VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT) + .srcStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT | VK_PIPELINE_STAGE_2_TRANSFER_BIT) + .srcAccessMask(VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT | VK_ACCESS_2_TRANSFER_READ_BIT | VK_ACCESS_2_TRANSFER_WRITE_BIT) + .dstStageMask(VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT | VK_PIPELINE_STAGE_2_TRANSFER_BIT) + .dstAccessMask(VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT | VK_ACCESS_2_TRANSFER_READ_BIT | VK_ACCESS_2_TRANSFER_WRITE_BIT) + val debugLabel = VkDebugUtilsLabelEXT + .calloc(stack) + .sType$Default() + .pLabelName(stack.UTF8("BARRIER")) + vkCmdBeginDebugUtilsLabelEXT(commandBuffer, debugLabel) + val dependencyInfo = VkDependencyInfo .calloc(stack) .sType$Default() @@ -206,7 +313,20 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte vkCmdPipelineBarrier2(commandBuffer, dependencyInfo) + vkCmdEndDebugUtilsLabelEXT(commandBuffer) + case Dispatch(pipeline, layout, descriptorSets, dispatch) => + // Add debug label for profiling (visible in Nsight Systems) + val dispatchSize = dispatch match + case Direct(x, y, z) => s"${x}x${y}x${z}" + case Indirect(_, _) => "indirect" + val labelName = s"${pipeline.name}[$dispatchSize]" + val debugLabel = VkDebugUtilsLabelEXT + .calloc(stack) + .sType$Default() + .pLabelName(stack.UTF8(labelName)) + vkCmdBeginDebugUtilsLabelEXT(commandBuffer, debugLabel) + vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.get) val pDescriptorSets = stack.longs(descriptorSets.map(_.get)*) @@ -215,6 +335,28 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte dispatch match case Direct(x, y, z) => vkCmdDispatch(commandBuffer, x, y, z) case Indirect(buffer, offset) => vkCmdDispatchIndirect(commandBuffer, VkAllocation.getUnderlying(buffer).buffer.get, offset) + + vkCmdEndDebugUtilsLabelEXT(commandBuffer) + + case BufferCopyStep(src, dst, sizeBytes) => + // Add debug label for profiling + val debugLabel = VkDebugUtilsLabelEXT + .calloc(stack) + .sType$Default() + .pLabelName(stack.UTF8(s"BufferCopy[${sizeBytes}B]")) + vkCmdBeginDebugUtilsLabelEXT(commandBuffer, debugLabel) + + val copyRegion = VkBufferCopy + .calloc(1, stack) + .srcOffset(0) + .dstOffset(0) + .size(sizeBytes) + + val srcBuffer = VkAllocation.getUnderlying(src).buffer.get + val dstBuffer = VkAllocation.getUnderlying(dst).buffer.get + vkCmdCopyBuffer(commandBuffer, srcBuffer, dstBuffer, copyRegion) + + vkCmdEndDebugUtilsLabelEXT(commandBuffer) check(vkEndCommandBuffer(commandBuffer), "Failed to finish recording 
command buffer") commandBuffer @@ -222,16 +364,21 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte private def getAllBindings(steps: Seq[ExecutionStep]): Seq[GBinding[?]] = steps .flatMap: - case Dispatch(_, layout, _, _) => layout.flatten.map(_.binding) - case PipelineBarrier => Seq.empty + case Dispatch(_, layout, _, _) => layout.flatten.map(_.binding) + case BufferCopyStep(src, dst, _) => Seq(src, dst) + case PipelineBarrier => Seq.empty .distinct object ExecutionHandler: - case class ShaderCall(pipeline: ComputePipeline, layout: ShaderLayout, dispatch: DispatchType) + /** Represents a call to be executed on GPU - either a shader dispatch or buffer copy. */ + sealed trait ExecutionCall + case class ShaderCall(pipeline: ComputePipeline, layout: ShaderLayout, dispatch: DispatchType) extends ExecutionCall + case class BufferCopyCall(src: GBinding[?], dst: GBinding[?], sizeBytes: Int) extends ExecutionCall sealed trait ExecutionStep case class Dispatch(pipeline: ComputePipeline, layout: ShaderLayout, descriptorSets: Seq[DescriptorSet], dispatch: DispatchType) extends ExecutionStep + case class BufferCopyStep(src: GBinding[?], dst: GBinding[?], sizeBytes: Int) extends ExecutionStep case object PipelineBarrier extends ExecutionStep sealed trait DispatchType diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala index 691a10ca..d839aa7c 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkAllocation.scala @@ -25,18 +25,22 @@ import io.computenode.cyfra.dsl.binding.{GBinding, GBuffer, GUniform} import io.computenode.cyfra.spirv.compilers.SpirvProgramCompiler.totalStride import scala.reflect.ClassTag import io.computenode.cyfra.core.GCodec +import io.computenode.cyfra.utility.NVTX -class VkAllocation(val commandPool: CommandPool.Reset, executionHandler: ExecutionHandler)(using Allocator, Device) extends Allocation: +class VkAllocation(val commandPool: CommandPool.Reset, val executionHandler: ExecutionHandler)(using Allocator, Device) extends Allocation: given VkAllocation = this override def submitLayout[L: Layout](layout: L): Unit = + // With timeline semaphores, ExecutionHandler tracks all submissions + // Only sync if there are old-style pending executions (from writes) val executions = Layout[L] .toBindings(layout) .flatMap(x => Try(getUnderlying(x)).toOption) .flatMap(_.execution.fold(Seq(_), _.toSeq)) .filter(_.isPending) - PendingExecution.executeAll(executions, this) + if executions.nonEmpty then + PendingExecution.executeAll(executions, this) extension (buffer: GBinding[?]) def read(bb: ByteBuffer, offset: Int = 0): Unit = @@ -44,10 +48,16 @@ class VkAllocation(val commandPool: CommandPool.Reset, executionHandler: Executi buffer match case VkBinding(buffer: Buffer.HostBuffer) => buffer.copyTo(bb, offset) case binding: VkBinding[?] 
=> + NVTX.push(s"Materialise[$buffer]") binding.materialise(this) + NVTX.pop() + NVTX.push(s"CopyToStaging[$buffer]") val stagingBuffer = getStagingBuffer(size) Buffer.copyBuffer(binding.buffer, stagingBuffer, offset, 0, size, commandPool) + NVTX.pop() + NVTX.push(s"CopyToHost[$buffer]") stagingBuffer.copyTo(bb, 0) + NVTX.pop() stagingBuffer.destroy() case _ => throw new IllegalArgumentException(s"Tried to read from non-VkBinding $buffer") @@ -128,6 +138,9 @@ class VkAllocation(val commandPool: CommandPool.Reset, executionHandler: Executi def addExecution(pe: PendingExecution): Unit = executions += pe + /** Check if there are any pending (not yet submitted) executions from buffer writes. */ + def hasPendingWrites: Boolean = executions.exists(_.isPending) + private val bindings = mutable.Buffer[VkUniform[?] | VkBuffer[?]]() private[cyfra] def close(): Unit = executions.filter(_.isRunning).foreach(_.block()) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala index 1cdcd83d..8e21963f 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkBinding.scala @@ -28,10 +28,8 @@ sealed abstract class VkBinding[T <: Value: {Tag, FromExpr}](val buffer: Buffer) var execution: Either[PendingExecution, mutable.Buffer[PendingExecution]] = Right(mutable.Buffer.empty) def materialise(allocation: VkAllocation)(using Device): Unit = - val allExecs = execution.fold(Seq(_), _.toSeq) // TODO better handle read only executions - allExecs.filter(_.isPending).pipe(PendingExecution.executeAll(_, allocation)) - allExecs.foreach(_.block()) - PendingExecution.cleanupAll(allExecs) + // Sync all GPU work via timeline semaphore before reading + allocation.executionHandler.sync() object VkBinding: def unapply(binding: GBinding[?]): Option[Buffer] = binding match diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala index 050bae1a..c2d7e315 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkCyfraRuntime.scala @@ -1,14 +1,18 @@ package io.computenode.cyfra.runtime import io.computenode.cyfra.core.GProgram.InitProgramLayout +import io.computenode.cyfra.core.binding.BufferRef import io.computenode.cyfra.core.layout.Layout import io.computenode.cyfra.core.{Allocation, CyfraRuntime, GExecution, GProgram, GioProgram, SpirvProgram} +import io.computenode.cyfra.dsl.binding.{WriteBuffer, WriteShared, WriteUniform} +import io.computenode.cyfra.dsl.gio.GIO import io.computenode.cyfra.spirv.compilers.DSLCompiler import io.computenode.cyfra.spirvtools.SpirvToolsRunner import io.computenode.cyfra.vulkan.VulkanContext import io.computenode.cyfra.vulkan.compute.ComputePipeline import java.security.MessageDigest +import scala.annotation.tailrec import scala.collection.mutable class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) extends CyfraRuntime: @@ -28,14 +32,17 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex case _ => throw new IllegalArgumentException(s"Unsupported program type: ${program.getClass.getName}") gProgramCache.update(program, spirvProgram) - shaderCache.getOrElseUpdate(spirvProgram.shaderHash, 
VkShader(spirvProgram)).asInstanceOf[VkShader[L]] + shaderCache.getOrElseUpdate(spirvProgram.shaderHash, VkShader(spirvProgram, program.name)).asInstanceOf[VkShader[L]] private def compile[Params, L: Layout as l](program: GioProgram[Params, L]): SpirvProgram[Params, L] = - val GioProgram(_, layout, dispatch, _) = program + val GioProgram(_, layout, dispatch, workgroupSize, programName) = program val bindings = l.toBindings(l.layoutRef).toList - val compiled = DSLCompiler.compile(program.body(l.layoutRef), bindings) + val bodyGio = program.body(l.layoutRef) + val compiled = DSLCompiler.compile(bodyGio, bindings, workgroupSize) val optimizedShaderCode = spirvToolsRunner.processShaderCodeWithSpirvTools(compiled) - SpirvProgram((il: InitProgramLayout) ?=> layout(il), dispatch, optimizedShaderCode) + // Extract written binding indices for smarter barrier insertion + val writtenBindingIndices: Set[Int] = VkCyfraRuntime.getWrittenBindingIndices(List(bodyGio), Set.empty) + SpirvProgram((il: InitProgramLayout) ?=> layout(il), dispatch, optimizedShaderCode, writtenBindingIndices, programName) override def withAllocation(f: Allocation => Unit): Unit = context.withThreadContext: threadContext => @@ -49,7 +56,28 @@ class VkCyfraRuntime(spirvToolsRunner: SpirvToolsRunner = SpirvToolsRunner()) ex context.destroy() object VkCyfraRuntime: - def using[T](f: VkCyfraRuntime ?=> T): T = - val runtime = new VkCyfraRuntime() + def using[T](f: VkCyfraRuntime ?=> T)(using spirvTools: SpirvToolsRunner = SpirvToolsRunner()): T = + val runtime = new VkCyfraRuntime(spirvTools) try f(using runtime) finally runtime.close() + + /** Extract binding indices of all GBuffers that are written to in the GIO program. + * Used for smarter barrier insertion - only written buffers cause conflicts. + * Returns Set of layoutOffset values (binding indices). 
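+   * Example (hypothetical layout with bindings input=0, weight=1, output=2): for a
+   * body whose only write is
+   * {{{
+   * GIO.write[Float32](layout.output, tid, x)
+   * }}}
+   * the result is Set(2); bindings that are only read never enter the set.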
+ */ + @tailrec + private[runtime] def getWrittenBindingIndices(pending: List[GIO[?]], acc: Set[Int]): Set[Int] = + pending match + case Nil => acc + case GIO.FlatMap(v, n) :: tail => + getWrittenBindingIndices(v :: n :: tail, acc) + case GIO.Repeat(_, gio, _) :: tail => + getWrittenBindingIndices(gio :: tail, acc) + case GIO.FoldRepeat(_, _, gio, _, _) :: tail => + getWrittenBindingIndices(gio :: tail, acc) + case GIO.ConditionalWhen(_, body) :: tail => + getWrittenBindingIndices(body :: tail, acc) + case WriteBuffer(buffer: BufferRef[?], _, _) :: tail => + getWrittenBindingIndices(tail, acc + buffer.layoutOffset) + case _ :: tail => + getWrittenBindingIndices(tail, acc) diff --git a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala index b63409a3..f4e6206d 100644 --- a/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala +++ b/cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/VkShader.scala @@ -13,11 +13,12 @@ import izumi.reflect.Tag import scala.util.{Failure, Success} -case class VkShader[L](underlying: ComputePipeline, shaderBindings: L => ShaderLayout) +case class VkShader[L](underlying: ComputePipeline, shaderBindings: L => ShaderLayout): + def name: String = underlying.name object VkShader: - def apply[P, L: Layout](program: SpirvProgram[P, L])(using Device): VkShader[L] = - val SpirvProgram(layout, dispatch, _workgroupSize, code, entryPoint, shaderBindings) = program + def apply[P, L: Layout](program: SpirvProgram[P, L], name: String = "Shader")(using Device): VkShader[L] = + val SpirvProgram(layout, dispatch, _workgroupSize, code, entryPoint, shaderBindings, _) = program val shaderLayout = shaderBindings(Layout[L].layoutRef) val sets = shaderLayout.map: set => @@ -29,5 +30,5 @@ object VkShader: DescriptorInfo(kind) DescriptorSetInfo(descriptors) - val pipeline = ComputePipeline(code, entryPoint, LayoutInfo(sets)) + val pipeline = ComputePipeline(code, entryPoint, LayoutInfo(sets), name) VkShader(pipeline, shaderBindings) diff --git a/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/NVTX.scala b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/NVTX.scala new file mode 100644 index 00000000..f97a67bc --- /dev/null +++ b/cyfra-utility/src/main/scala/io/computenode/cyfra/utility/NVTX.scala @@ -0,0 +1,62 @@ +package io.computenode.cyfra.utility + +import com.sun.jna.{Library, Native} + +/** NVTX (NVIDIA Tools Extension) wrapper for CPU profiling markers. + * + * These markers show up in Nsight Systems timeline alongside Vulkan GPU work. + * + * Usage: + * {{{ + * import io.computenode.cyfra.utility.NVTX + * + * NVTX.range("Forward Pass") { + * // code to profile + * } + * + * // Or manual push/pop: + * NVTX.push("Token Generation") + * // ... + * NVTX.pop() + * }}} + * + * Requires: CUDA toolkit installed with libnvtx3interop.so in library path. 
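+ * Enable markers with -Dio.computenode.cyfra.nvtx.enabled=true; when the property
+ * is unset or false, push/pop/mark are cheap no-ops.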
+ * Run with: LD_LIBRARY_PATH=/opt/cuda/lib64:$LD_LIBRARY_PATH + */ +object NVTX: + + private val enabled: Boolean = System.getProperty("io.computenode.cyfra.nvtx.enabled", "false").toBoolean + + private trait NVTXLib extends Library: + def nvtxRangePushA(message: String): Int + def nvtxRangePop(): Int + def nvtxMarkA(message: String): Unit + + private lazy val lib: Option[NVTXLib] = + try + Some(Native.load("nvtx3interop", classOf[NVTXLib])) + catch + case e: UnsatisfiedLinkError => + System.err.println(s"[NVTX] Library not found: ${e.getMessage}") + None + + /** Push a named range onto the NVTX stack. Must be paired with pop(). */ + def push(name: String): Unit = + if enabled then lib.foreach(_.nvtxRangePushA(name)) + + /** Pop the current range from the NVTX stack. */ + def pop(): Unit = + if enabled then lib.foreach(_.nvtxRangePop()) + + /** Place an instant marker (point in time, not a range). */ + def mark(name: String): Unit = + if enabled then lib.foreach(_.nvtxMarkA(name)) + + /** Execute body within a named NVTX range. */ + inline def range[T](name: String)(body: => T): T = + push(name) + try body + finally pop() + + /** Check if NVTX is available. */ + def isAvailable: Boolean = lib.isDefined diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala index 2fe2c35d..5fbe06e8 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/compute/ComputePipeline.scala @@ -17,7 +17,7 @@ import scala.util.{Try, Using} /** @author * MarconZet Created 14.04.2020 */ -private[cyfra] class ComputePipeline(shaderCode: ByteBuffer, functionName: String, layoutInfo: LayoutInfo)(using device: Device) +private[cyfra] class ComputePipeline(shaderCode: ByteBuffer, functionName: String, layoutInfo: LayoutInfo, val name: String = "Shader")(using device: Device) extends VulkanObjectHandle: private val shader: Long = pushStack: stack => // TODO khr_maintenance5 diff --git a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala index f8661f6d..dc45d0e9 100644 --- a/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala +++ b/cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/core/Instance.scala @@ -26,7 +26,9 @@ import scala.util.chaining.* object Instance: private val ValidationLayer: String = "VK_LAYER_KHRONOS_validation" private val ValidationLayersExtensions: Seq[String] = - List(VK_EXT_DEBUG_REPORT_EXTENSION_NAME, VK_EXT_DEBUG_UTILS_EXTENSION_NAME, VK_EXT_LAYER_SETTINGS_EXTENSION_NAME) + List(VK_EXT_DEBUG_REPORT_EXTENSION_NAME, VK_EXT_LAYER_SETTINGS_EXTENSION_NAME) + // Always load debug utils for profiling markers (even without validation) + private val AlwaysEnabledExtensions: Seq[String] = List(VK_EXT_DEBUG_UTILS_EXTENSION_NAME) private val MoltenVkExtensions: Seq[String] = List(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME) lazy val (extensions, layers): (Seq[String], Seq[String]) = pushStack: stack => @@ -126,6 +128,7 @@ private[cyfra] class Instance(enableValidationLayers: Boolean, enablePrinting: B buf.toSet val extensions = mutable.Buffer.from(Instance.MoltenVkExtensions) + extensions.addAll(Instance.AlwaysEnabledExtensions) if enableValidationLayers then extensions.addAll(Instance.ValidationLayersExtensions) val filteredExtensions = 
extensions.filter(ext =>