Вы находитесь на странице: 1 из 2

/**
 * Batch gradient descent for linear least-squares regression (no implicit bias term).
 *
 * <p>Samples are stored row-major in one flat array: feature {@code j} of sample
 * {@code i} is {@code sample[i * sample_dimension + j]}.
 *
 * <p>Each {@link #Update()} step follows the gradient of the mean-scaled cost
 * <pre>
 *   J(w) = 1/(2n) * sum_i (w . x[i] - y[i])^2
 * </pre>
 * whose partial derivatives are
 * <pre>
 *   dJ/dw[j] = 1/n * sum_i (w . x[i] - y[i]) * x[i][j]
 * </pre>
 *
 * <p>The {@code sample} and {@code y} arrays passed to the constructor are kept by
 * reference, not copied.
 */
public class BatchGradientDescent {

    /** Flattened row-major samples (aliased from the caller). */
    private final double[] sample_;
    /** Target value of each sample (aliased from the caller). */
    private final double[] y_;
    /** Current weights, one per feature; initialized to zero. */
    private final double[] w_;
    /** Scratch buffer holding the residual w . x[i] - y[i] of every sample for one step. */
    private final double[] residual_;
    private final double learning_rate_;
    private final int sample_dimension_;
    private final int num_of_sample_;

    /**
     * Creates a solver with all weights set to zero.
     *
     * @param sample_dimension number of features per sample; must be positive
     * @param sample flattened samples; trailing values that do not fill a whole sample
     *     are ignored
     * @param y targets; must hold at least {@code sample.length / sample_dimension} entries
     * @param learning_rate step size multiplied by the gradient in each update
     * @throws IllegalArgumentException if {@code sample_dimension <= 0} or {@code y} is
     *     shorter than the number of samples
     */
    public BatchGradientDescent(int sample_dimension, double[] sample, double[] y,
            double learning_rate) {
        if (sample_dimension <= 0) {
            throw new IllegalArgumentException(
                    "sample_dimension must be positive: " + sample_dimension);
        }
        int sample_count = sample.length / sample_dimension;
        if (y.length < sample_count) {
            throw new IllegalArgumentException(
                    "y has " + y.length + " entries but there are " + sample_count + " samples");
        }
        sample_ = sample;
        y_ = y;
        w_ = new double[sample_dimension];
        residual_ = new double[sample_count];
        learning_rate_ = learning_rate;
        sample_dimension_ = sample_dimension;
        num_of_sample_ = sample_count;
    }

    /**
     * Performs one batch gradient-descent step on all weights simultaneously.
     *
     * @return the Euclidean norm of the gradient used for this step (computed before the
     *     weights change); 0 when there are no samples, in which case the weights are left
     *     unchanged
     */
    public double Update() {
        if (num_of_sample_ == 0) {
            // The mean gradient would be 0/0 = NaN and would poison every weight.
            return 0;
        }
        // Residuals depend only on the current weights, so compute them once per step
        // instead of once per feature. Because every gradient below reads only these
        // cached residuals, the weights can be overwritten in place while still
        // performing a simultaneous update.
        for (int i = 0; i < num_of_sample_; ++i) {
            residual_[i] = Residual(i);
        }
        double squared_norm = 0;
        for (int j = 0; j < sample_dimension_; ++j) {
            double gradient = 0;
            for (int i = 0; i < num_of_sample_; ++i) {
                gradient += residual_[i] * sample_[i * sample_dimension_ + j];
            }
            gradient /= num_of_sample_;
            w_[j] -= learning_rate_ * gradient;
            squared_norm += gradient * gradient;
        }
        return Math.sqrt(squared_norm);
    }

    /**
     * Returns the solver's internal weight array itself, not a copy: later updates are
     * visible through it, and writes to it change the model.
     */
    public double[] GetWeight() {
        return w_;
    }

    /**
     * Returns the sum of squared residuals, sum_i (w . x[i] - y[i])^2, under the current
     * weights. This is the plain sum, without the 1/(2n) factor used in J(w).
     */
    public double ComputeError() {
        double error = 0;
        for (int i = 0; i < num_of_sample_; ++i) {
            double r = Residual(i);
            error += r * r;
        }
        return error;
    }

    /** Returns w . x[i] - y[i] for sample {@code i}, subtracting the target exactly once. */
    private double Residual(int i) {
        final int row = i * sample_dimension_;
        double r = -y_[i];
        for (int k = 0; k < sample_dimension_; ++k) {
            r += sample_[row + k] * w_[k];
        }
        return r;
    }
}

Вам также может понравиться