Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions infini_train/include/autograd/linear.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@ class Tensor;

namespace infini_train::autograd {

struct LinearGradFlags {
bool input = false;
bool weight = false;
bool bias = false;
};

class Linear : public Function {
public:
static constexpr char kType[] = "LinearFunction";
Expand Down
30 changes: 21 additions & 9 deletions infini_train/src/autograd/linear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,29 @@ std::vector<std::shared_ptr<Tensor>> Linear::Backward(const std::vector<std::sha
const auto &grad_output = grad_outputs[0];

CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Linear::Backward";
LinearGradFlags grad_flags = {.input = needs_input_grad_[0],
.weight = needs_input_grad_.size() > 1 && needs_input_grad_[1],
.bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2]};
bool need_grad_input = needs_input_grad_[0];
bool need_grad_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1];
bool need_grad_bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2];

auto device = grad_output->GetDevice().type();
// TODO: skip autograd graph construction entirely when no input requires grad
auto [grad_input, grad_weight, grad_bias]
= Dispatcher::Instance()
.Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
{device, "LinearBackward"}, input, weight, transpose_, in_features_, out_features_, input_dims_,
grad_output, bias_, grad_flags);

std::shared_ptr<Tensor> grad_input = nullptr;
std::shared_ptr<Tensor> grad_weight = nullptr;
std::shared_ptr<Tensor> grad_bias = nullptr;

if (need_grad_input) {
grad_input = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>(
{device, "LinearBackwardInput"}, weight, grad_output, transpose_, in_features_, out_features_, input_dims_);
}
if (need_grad_weight) {
grad_weight = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>(
{device, "LinearBackwardWeight"}, input, grad_output, transpose_, in_features_, out_features_);
}
if (need_grad_bias) {
grad_bias = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "LinearBackwardBias"}, grad_output,
out_features_);
}

if (bias_) {
return {grad_input, grad_weight, grad_bias};
} else {
Expand Down
33 changes: 27 additions & 6 deletions infini_train/src/autograd/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,17 @@ void Matmul::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tens
// FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be
// determined by autocast, not derived from output->Dtype().
auto compute_dtype = output->Dtype();
saved_tensors_ = {
input1->Dtype() == compute_dtype ? input1 : std::make_shared<Tensor>(input1->To(compute_dtype)),
input2->Dtype() == compute_dtype ? input2 : std::make_shared<Tensor>(input2->To(compute_dtype)),

// grad_input1 = grad_output @ input2^T, so input2 is needed
// grad_input2 = grad_output^T @ input1, so input1 is needed
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];

auto cast = [&](const std::shared_ptr<Tensor> &t) {
return t->Dtype() == compute_dtype ? t : std::make_shared<Tensor>(t->To(compute_dtype));
};

saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr};
out_features_ = output->Dims()[0];
}

Expand All @@ -45,10 +52,24 @@ std::vector<std::shared_ptr<Tensor>> Matmul::Backward(const std::vector<std::sha
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Matmul::Backward";
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];

auto device = input1->GetDevice().type();
auto [grad_input1, grad_input2]
= Dispatcher::Instance().Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
{device, "MatmulBackward"}, input1, input2, grad_output);

std::shared_ptr<Tensor> grad_input1 = nullptr;
std::shared_ptr<Tensor> grad_input2 = nullptr;

if (need_grad_input1) {
grad_input1 = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardInput1"}, input2,
grad_output, input1->Dims());
}
if (need_grad_input2) {
grad_input2 = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardInput2"}, input1,
grad_output, input2->Dims());
}

return {grad_input1, grad_input2};
}
} // namespace infini_train::autograd
155 changes: 90 additions & 65 deletions infini_train/src/kernels/cpu/linear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

#include "glog/logging.h"

#include "infini_train/include/autograd/linear.h"
#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"

Expand Down Expand Up @@ -51,38 +50,71 @@ std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, cons
return {output};
}

std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other,
const std::shared_ptr<Tensor> &grad_output) {
std::shared_ptr<Tensor> MatmulBackwardInput1(const std::shared_ptr<Tensor> &other,
const std::shared_ptr<Tensor> &grad_output,
const std::vector<int64_t> &input_dims) {
/*
grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T
grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n]
*/
const auto &input_dims = input->Dims();
const auto &other_dims = other->Dims();
const auto &grad_output_dims = grad_output->Dims();

CHECK_GE(other_dims.size(), 2);
CHECK_EQ(other_dims.size(), grad_output_dims.size());

const int64_t m = grad_output_dims[grad_output_dims.size() - 2];
const int64_t k = other_dims[other_dims.size() - 2];
const int64_t n = grad_output_dims[grad_output_dims.size() - 1];

const int64_t bs
= std::accumulate(grad_output_dims.rbegin() + 2, grad_output_dims.rend(), 1, std::multiplies<int64_t>{});
for (int64_t i = 0; i < grad_output_dims.size() - 2; ++i) {
CHECK_EQ(grad_output_dims[i], other_dims[i]) << "Batch dims must match";
}

auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
grad_input->Fill<float>(0.0f);

for (int64_t b = 0; b < bs; ++b) {
for (int64_t i = 0; i < m; ++i) {
for (int64_t j = 0; j < n; ++j) {
const float grad = static_cast<float *>(grad_output->DataPtr())[b * m * n + i * n + j];
for (int64_t p = 0; p < k; ++p) {
const auto other_idx = b * k * n + p * n + j;
static_cast<float *>(grad_input->DataPtr())[b * m * k + i * k + p]
+= grad * static_cast<const float *>(other->DataPtr())[other_idx];
}
}
}
}
return grad_input;
}

// Gradient of matmul w.r.t. the SECOND operand:
//   grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n]
// Only `input1` is needed; `other_dims` carries the second operand's shape so
// the saved input2 tensor itself does not have to be retained.
// NOTE(review): assumes contiguous row-major float32 storage (grounded by the
// raw float* indexing and DataType::kFLOAT32 below).
std::shared_ptr<Tensor> MatmulBackwardInput2(const std::shared_ptr<Tensor> &input1,
                                             const std::shared_ptr<Tensor> &grad_output,
                                             const std::vector<int64_t> &other_dims) {
    /*
    grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n]
    */
    const auto &input_dims = input1->Dims();
    const auto &grad_output_dims = grad_output->Dims();

    CHECK_GE(input_dims.size(), 2);
    CHECK_EQ(input_dims.size(), other_dims.size());
    CHECK_EQ(input_dims.size(), grad_output_dims.size());

    const int64_t m = input_dims[input_dims.size() - 2];
    const int64_t k = input_dims[input_dims.size() - 1];
    const int64_t n = other_dims[other_dims.size() - 1];
    CHECK_EQ(k, other_dims[other_dims.size() - 2]);
    CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]);
    CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]);

    // Product of all leading (batch) dims; 1 when the operands are plain 2-D.
    const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies<int64_t>{});
    for (int64_t i = 0; i < static_cast<int64_t>(input_dims.size()) - 2; ++i) {
        CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match";
        CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match";
    }

    auto grad_other = std::make_shared<Tensor>(other_dims, DataType::kFLOAT32);
    grad_other->Fill<float>(0.0f);

    // Hoist the raw pointers out of the hot loops; they are loop-invariant.
    const float *grad_ptr = static_cast<const float *>(grad_output->DataPtr());
    const float *input_ptr = static_cast<const float *>(input1->DataPtr());
    float *out_ptr = static_cast<float *>(grad_other->DataPtr());

    for (int64_t b = 0; b < bs; ++b) {
        for (int64_t i = 0; i < m; ++i) {
            for (int64_t j = 0; j < n; ++j) {
                const float grad = grad_ptr[b * m * n + i * n + j];
                for (int64_t p = 0; p < k; ++p) {
                    // grad_other[b,p,j] += input[b,i,p] * grad_output[b,i,j]
                    out_ptr[b * k * n + p * n + j] += grad * input_ptr[b * m * k + i * k + p];
                }
            }
        }
    }
    return grad_other;
}

std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight,
Expand Down Expand Up @@ -146,71 +175,67 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
return output;
}

// TODO(dcj): support linear without bias later
std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight, bool transpose,
int64_t in_features, int64_t out_features, const std::vector<int64_t> &input_dims,
const std::shared_ptr<Tensor> &grad_output, bool bias,
infini_train::autograd::LinearGradFlags grad_flags) {
std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weight,
const std::shared_ptr<Tensor> &grad_output, bool transpose,
int64_t in_features, int64_t out_features,
const std::vector<int64_t> &input_dims) {
/*
transpose: grad_input = grad_output * weight
grad_input[*, in_features] = grad_output[*, out_features] * weight[out_features, in_features]
grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features]
grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)

!transpose: grad_input = grad_output * weight^T
grad_input[*, in_features] = grad_output[_, out_features] * weight[in_features, out_features]^T
grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features]
grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
*/
const auto compute_grad_input = grad_flags.input;
const auto compute_grad_weight = grad_flags.weight;
const auto compute_grad_bias = grad_flags.bias;

CHECK_GE(input_dims.size(), 2);

std::vector<int64_t> weight_dims
= transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};

std::shared_ptr<Tensor> grad_input = nullptr;
std::shared_ptr<Tensor> grad_weight = nullptr;
std::shared_ptr<Tensor> grad_bias = nullptr;

if (compute_grad_input) {
CHECK(weight != nullptr) << "compute_grad_input=true but weight is nullptr (selective save mismatch)";
grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
if (transpose) {
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
} else {
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
}
auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
if (transpose) {
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
} else {
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
}
return grad_input;
}

if (compute_grad_weight) {
CHECK(input != nullptr) << "compute_grad_weight=true but input is nullptr (selective save mismatch)";
grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
if (transpose) {
grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
} else {
grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
}
}
// Gradient of Linear w.r.t. its weight matrix. The weight layout depends on
// `transpose`: [out_features, in_features] when true, else
// [in_features, out_features].
std::shared_ptr<Tensor> LinearBackwardWeight(const std::shared_ptr<Tensor> &input,
                                             const std::shared_ptr<Tensor> &grad_output, bool transpose,
                                             int64_t in_features, int64_t out_features) {
    /*
    transpose:
    grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features]

    !transpose:
    grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features]
    */
    std::vector<int64_t> weight_dims
        = transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};
    auto grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
    if (transpose) {
        grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
    } else {
        grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
    }
    return grad_weight;
}
// Gradient of Linear w.r.t. its bias: sum grad_output over every leading
// (batch/row) dimension, leaving a vector of length out_features.
std::shared_ptr<Tensor> LinearBackwardBias(const std::shared_ptr<Tensor> &grad_output, int64_t out_features) {
    /*
    grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
    */
    // colwise().sum() reduces the row axis of the 2-D Eigen view.
    // NOTE(review): assumes grad_output's trailing dim == out_features so the
    // column count matches grad_bias -- confirm against the caller's CHECKs.
    auto grad_bias = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, DataType::kFLOAT32);
    grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum();
    return grad_bias;
}
} // namespace infini_train::kernels::cpu

// Registers each CPU kernel under its own name with the global dispatcher.
#define REGISTER_CPU_LINEAR_KERNEL(kernel_name)                                                                        \
    REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name)

REGISTER_CPU_LINEAR_KERNEL(MatmulForward)
REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput1)
REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput2)
REGISTER_CPU_LINEAR_KERNEL(LinearForward)
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardInput)
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardWeight)
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardBias)

#undef REGISTER_CPU_LINEAR_KERNEL
Loading
Loading