Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions compute/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@ import (
// FusedRMSNormer is an optional interface for engines that support GPU-accelerated
// fused RMSNorm. Layers can type-assert to this to use the fused kernel.
// Returns (output, scales) where scales contains per-row rsqrt values for backward pass.
//
// This API is not covered by the v1 stability guarantee.
type FusedRMSNormer interface {
// FusedRMSNormGPU normalizes input against weight with the given epsilon,
// returning the normalized output plus the per-row scale factors that the
// backward pass reuses.
FusedRMSNormGPU(input, weight *tensor.TensorNumeric[float32], epsilon float32) (output, scales *tensor.TensorNumeric[float32], err error)
}

// PoolResetter is an optional interface for engines that use arena-based
// memory pools. Call ResetPool() at the start of each forward pass to
// reclaim all per-pass intermediate allocations in O(1).
//
// This API is not covered by the v1 stability guarantee.
type PoolResetter interface {
// ResetPool returns every per-pass intermediate allocation to the arena
// in constant time; call once at the start of each forward pass.
ResetPool()
}
Expand All @@ -27,6 +31,8 @@ type PoolResetter interface {
// model weights to device memory at load time. This eliminates per-operation
// host-to-device copies during inference. Each tensor's storage is replaced
// in-place from CPUStorage to device-resident storage.
//
// This API is not covered by the v1 stability guarantee.
type WeightUploader interface {
// UploadWeights moves each tensor's backing storage to the device,
// mutating the tensors in place; an error leaves the remainder on host.
UploadWeights(tensors []*tensor.TensorNumeric[float32]) error
}
Expand All @@ -35,12 +41,16 @@ type WeightUploader interface {
// C = A * B^T without explicitly transposing B. This avoids an extra
// GPU allocation and kernel launch for the transpose operation.
// A is [batch, m, k], B is [batch, n, k], result is [batch, m, n].
//
// This API is not covered by the v1 stability guarantee.
type TransposeBMatMuler[T tensor.Numeric] interface {
// MatMulTransposeB computes a * b^T; dst, if supplied, receives the
// result in place, otherwise a new tensor is returned.
MatMulTransposeB(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
}

// StreamProvider is an optional interface for engines that expose their
// underlying GPU stream for CUDA graph capture.
//
// This API is not covered by the v1 stability guarantee.
type StreamProvider interface {
// Stream returns the engine's GPU stream as an unsafe.Pointer (cudaStream_t).
Stream() unsafe.Pointer
Expand All @@ -49,20 +59,26 @@ type StreamProvider interface {
// GPUStreamAccessor is an optional interface for engines that provide their
// gpuapi.Stream for async memory operations (e.g., KV cache D2D copies
// during CUDA graph capture).
//
// This API is not covered by the v1 stability guarantee.
type GPUStreamAccessor interface {
// GPUStream returns the engine's stream handle for issuing async
// device memory operations.
GPUStream() gpuapi.Stream
}

// GPUArgmaxer is an optional interface for engines that can compute argmax
// entirely on GPU, returning just the index without copying logits to host.
// This eliminates the ~1MB D2H copy per token for greedy decoding.
//
// This API is not covered by the v1 stability guarantee.
type GPUArgmaxer interface {
// GPUArgmax returns the index of the maximum element of t, computed on
// the device so only the single index crosses to the host.
GPUArgmax(t *tensor.TensorNumeric[float32]) (int, error)
}

// FP16ToF32Converter is an optional interface for engines that can convert
// a tensor with Float16Storage to a regular float32 GPU tensor. This is used
// at the end of the FP16 forward pass to produce F32 logits for sampling.
//
// This API is not covered by the v1 stability guarantee.
type FP16ToF32Converter interface {
// ConvertFP16ToF32 returns a new float32 GPU tensor widened from t's
// half-precision storage; t itself is left unchanged.
ConvertFP16ToF32(t *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error)
}
Expand All @@ -72,6 +88,8 @@ type FP16ToF32Converter interface {
// supports paged attention, callers can pass block pointers and indices
// instead of contiguous KV tensors.
//
// This API is not covered by the v1 stability guarantee.
//
// Q: [batch*numQHeads, headDim]
// blockPtrsK: device array of float* pointers to K blocks
// blockPtrsV: device array of float* pointers to V blocks
Expand Down
2 changes: 2 additions & 0 deletions compute/engine_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
)

// TraceRecorder is the interface used by EngineProxy to record traced operations.
//
// This API is not covered by the v1 stability guarantee.
type TraceRecorder[T tensor.Numeric] interface {
Record(opName string, inputs []*tensor.TensorNumeric[T], output *tensor.TensorNumeric[T], extra map[string]any)
RecordMultiOutput(opName string, inputs []*tensor.TensorNumeric[T], outputs []*tensor.TensorNumeric[T], extra map[string]any)
Expand Down
2 changes: 2 additions & 0 deletions compute/flash_decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
// O: [batch * numQHeads, headDim] — output (caller-allocated).
//
// Supports GQA: numQHeads must be a multiple of numKVHeads.
//
// This API is not covered by the v1 stability guarantee.
func FlashDecode(
Q, K, V, O []float32,
batch, numQHeads, numKVHeads, kvLen, headDim int,
Expand Down
2 changes: 2 additions & 0 deletions compute/fused_add_rmsnorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
// FusedAddRMSNormProvider is implemented by engines that support fused
// residual-add + RMS normalization in a single GPU kernel launch.
// This eliminates one kernel launch per fusion point (2 per transformer layer).
//
// This API is not covered by the v1 stability guarantee.
type FusedAddRMSNormProvider[T tensor.Numeric] interface {
// GPUFusedAddRMSNorm computes:
// sum = input + residual
Expand Down
2 changes: 2 additions & 0 deletions compute/fused_norm_add.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
// RMSNorm + elementwise Add in a single GPU kernel launch.
// output = rmsnorm(input, weight, eps) + residual.
// This eliminates one kernel launch per fusion point.
//
// This API is not covered by the v1 stability guarantee.
type FusedNormAddProvider[T tensor.Numeric] interface {
// GPUFusedNormAdd computes:
// normed = rmsnorm(input, weight, eps)
Expand Down
2 changes: 2 additions & 0 deletions compute/fused_qk_norm_rope.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
// per-head QK RMSNorm + RoPE in a single GPU kernel launch.
// This replaces 4 kernel launches (Q_norm + K_norm + Q_RoPE + K_RoPE)
// with 1 per GQA layer during decode.
//
// This API is not covered by the v1 stability guarantee.
type FusedQKNormRoPEProvider[T tensor.Numeric] interface {
// GPUFusedQKNormRoPE applies per-head RMSNorm + RoPE to combined Q+K data.
// input: [totalHeads, headDim] (Q heads then K heads, contiguous).
Expand Down
2 changes: 2 additions & 0 deletions compute/fused_rmsnorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
// Weight shape: [D].
// Returns (output, scales) where output has same shape as input and scales
// has shape [..., 1] containing the per-row rsqrt(mean(x^2)+eps) values.
//
// This API is not covered by the v1 stability guarantee.
func FusedRMSNorm(input, weight *tensor.TensorNumeric[float32], epsilon float32) (output, scales *tensor.TensorNumeric[float32], err error) {
shape := input.Shape()
D := shape[len(shape)-1]
Expand Down
4 changes: 4 additions & 0 deletions compute/fused_rope.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@ import (
)

// FusedRoPEProvider is implemented by engines that support fused GPU RoPE.
//
// This API is not covered by the v1 stability guarantee.
type FusedRoPEProvider[T tensor.Numeric] interface {
// GPUFusedRoPE applies rotary position embeddings to input in a single
// kernel using the precomputed cosAngles/sinAngles; only the first
// rotaryDim dimensions of each head receive rotation.
GPUFusedRoPE(input, cosAngles, sinAngles *tensor.TensorNumeric[T], rotaryDim int) (*tensor.TensorNumeric[T], error)
}

// FusedRoPE applies rotary position embeddings in a single pass.
//
// This API is not covered by the v1 stability guarantee.
//
// Input shape: [batch, seq_len, head_dim] where head_dim is even.
// cos/sin shape: [seq_len, half_dim] (precomputed angles).
// rotaryDim: number of dimensions that receive rotation (<= head_dim, must be even).
// For each position (b, s):
Expand Down
2 changes: 2 additions & 0 deletions compute/fused_scaled_softmax.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
// FusedScaledSoftmaxProvider is implemented by engines that support fused GPU scaled softmax.
// It computes output = softmax(input * scale) in a single kernel launch,
// eliminating the MulScalar + Softmax chain (saves 1 kernel launch per call).
//
// This API is not covered by the v1 stability guarantee.
type FusedScaledSoftmaxProvider[T tensor.Numeric] interface {
// GPUScaledSoftmax returns softmax(input * scale) taken along axis.
GPUScaledSoftmax(input *tensor.TensorNumeric[T], scale float32, axis int) (*tensor.TensorNumeric[T], error)
}
2 changes: 2 additions & 0 deletions compute/fused_silugate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
// SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)).
// gate and up must have the same shape.
// This avoids materializing separate sigmoid, mul, and mul intermediate tensors.
//
// This API is not covered by the v1 stability guarantee.
func FusedSiLUGate(gate, up *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error) {
gShape := gate.Shape()
uShape := up.Shape()
Expand Down
2 changes: 2 additions & 0 deletions compute/fused_swiglu.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
// FusedSwiGLUProvider is implemented by engines that support fused GPU SwiGLU.
// It computes output[i] = w1[i] * sigmoid(w1[i]) * w3[i] in a single kernel,
// eliminating the Concat + Split + sigmoid + Mul + Mul chain.
//
// This API is not covered by the v1 stability guarantee.
type FusedSwiGLUProvider[T tensor.Numeric] interface {
// GPUFusedSwiGLU returns SiLU(w1) * w3 computed elementwise in one
// kernel launch; w1 and w3 must share a shape.
GPUFusedSwiGLU(w1, w3 *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
}
8 changes: 7 additions & 1 deletion compute/testable_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"github.com/zerfoo/ztensor/tensor"
)

// TestableEngine extends CPUEngine with methods that allow controlled error injection
// TestableEngine extends CPUEngine with methods that allow controlled error injection.
// This enables testing of previously unreachable error paths.
//
// This API is not covered by the v1 stability guarantee.
type TestableEngine[T tensor.Numeric] struct {
// Embedded CPU engine supplies the real implementations; the Testable*
// wrapper methods layer fault injection on top of it.
*CPUEngine[T]
}
Expand All @@ -23,6 +25,8 @@ func NewTestableEngine[T tensor.Numeric](ops numeric.Arithmetic[T]) *TestableEng
}

// FailableTensor wraps a tensor and can be configured to fail on specific operations.
//
// This API is not covered by the v1 stability guarantee.
type FailableTensor[T tensor.Numeric] struct {
*tensor.TensorNumeric[T]
failOnSet bool
Expand Down Expand Up @@ -140,6 +144,8 @@ func (e *TestableEngine[T]) TestableTranspose(_ context.Context, a *tensor.Tenso
}

// FailableZeroer can be configured to fail on Zero operations.
//
// This API is not covered by the v1 stability guarantee.
type FailableZeroer[T tensor.Numeric] struct {
engine *TestableEngine[T]
failZero bool
Expand Down
Loading
Loading