Skip to content

Unexpected Tensor Rank Error when Using LayerNorm Layer #9

@Florian9M

Description

@Florian9M

Description

I used the code provided in the README.md file to define the LayerNorm layer. However, I encountered an error indicating an unexpected tensor rank.

Error

Running `target\debug\candle-roberta.exe`
Error: unexpected rank, expected: 3, got: 1 ([3])

main.rs

mod roberta_model;

use candle_core::{DType, Device, Module, Tensor};
use crate::roberta_model::LayerNorm;

fn main() -> anyhow::Result<()> {
    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;

    // initialize a layer norm layer
    let layer_norm = LayerNorm::new(w_gen, b_gen, 1f64);

    let data: [u32; 3] = [1u32, 2, 3];
    let input_tensor = Tensor::new(&data, &Device::Cpu)?;
    let normalized_tensor = layer_norm.forward(&input_tensor)?;
    Ok(())
}

roberta_model.rs

use candle_core::{DType, Tensor};
use candle_nn::{Embedding, Linear, VarBuilder};

fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> anyhow::Result<Embedding> {
    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
    Ok(Embedding::new(embeddings, hidden_size))
}

fn linear(size1: usize, size2: usize, vb: VarBuilder) -> anyhow::Result<Linear> {
    let weight = vb.get((size2, size1), "weight")?;
    let bias = vb.get(size2, "bias")?;
    Ok(Linear::new(weight, Some(bias)))
}

pub struct LayerNorm {
    weight: Tensor, // Weight vector of the LayerNorm Layer
    bias: Tensor, // Bias vector of the LayerNorm Layer
    eps: f64, // Epsilon value for numerical stability
}

impl LayerNorm {
    // Constructor for LayerNorm
    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
        Self { weight, bias, eps }
    }

    pub fn forward(&self, x: &Tensor) -> anyhow::Result<Tensor> {
        let x_dtype = x.dtype(); // Get the data type of the input tensor
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        let (_bsize, _seq_len, hidden_size) = x.dims3()?; // Get the dimensions of the input tensor
        let x = x.to_dtype(internal_dtype)?;
        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?; // Get the mean of the input tensor and divide by the hidden size
        let x = x.broadcast_sub(&mean_x)?; // Subtract the mean from the input tensor
        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; // Get the squared norm of the input tensor and divide by the hidden size
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; // Get the normalized input
        let x = x_normed
            .to_dtype(x_dtype)?
            .broadcast_mul(&self.weight)?
            .broadcast_add(&self.bias)?;
        Ok(x)
    }
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions