RidgeRegression

class torch_rc.optim.RidgeRegression(params: Iterable, l2_reg: float = 1)

Implements ridge regression, that is, linear least squares with L2 regularization.

Parameters
  • params – iterable of parameters to optimize. Since this algorithm can only train linear layers, the iterable must contain exactly 2 parameters: one for the weights and one for the biases.

  • l2_reg – regularization strength. Must be a positive floating point number.
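
For reference, the optimization problem can be written as below. This is a sketch under assumptions not stated above: the two parameters are taken to be the weight W (shape n_targets × n_features) and bias b of a linear layer, X is the input, Y is the expected output, λ is l2_reg, and the bias is assumed to be left unpenalized (the library may instead regularize it too). Because λ > 0, the matrix on the left of the normal equations is positive definite, so the solution is unique.

```latex
\min_{W,\,b}\; \lVert X W^\top + \mathbf{1} b^\top - Y \rVert_F^2 + \lambda \lVert W \rVert_F^2,
\qquad \lambda = \texttt{l2\_reg}

\tilde X = \begin{bmatrix} X & \mathbf{1} \end{bmatrix}, \quad
\Theta = \begin{bmatrix} W^\top \\ b^\top \end{bmatrix}, \quad
D = \operatorname{diag}(1, \dots, 1, 0)

\left(\tilde X^\top \tilde X + \lambda D\right) \Theta = \tilde X^\top Y
```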

fit(input, expected)

Solves the ridge regression problem on the given data and overwrites the parameters with the solution.

Parameters
  • input – input tensor of shape (batch, n_features)

  • expected – target tensor of shape (batch, n_targets)
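
What fit does can be pictured as a single closed-form solve. The following is a minimal sketch, assuming the two parameters are the weight and bias of a torch.nn.Linear readout and that the bias is not penalized. ridge_fit_linear is a hypothetical helper written for illustration; it is not part of torch_rc, and the library's actual implementation may differ.

```python
import torch
from torch import nn


def ridge_fit_linear(layer: nn.Linear, input: torch.Tensor,
                     expected: torch.Tensor, l2_reg: float = 1.0) -> None:
    """Hypothetical helper: overwrite layer.weight and layer.bias with the
    closed-form ridge solution (assumption: the bias is not penalized)."""
    n_samples, n_features = input.shape
    # Append a column of ones so the bias is solved jointly with the weights.
    ones = torch.ones(n_samples, 1, dtype=input.dtype, device=input.device)
    x_aug = torch.cat([input, ones], dim=1)  # (batch, n_features + 1)
    # Penalty matrix: l2_reg on the weight rows, 0 on the bias row.
    penalty = l2_reg * torch.eye(n_features + 1, dtype=input.dtype, device=input.device)
    penalty[-1, -1] = 0.0
    # Normal equations: (X^T X + lambda * D) theta = X^T Y
    lhs = x_aug.T @ x_aug + penalty
    rhs = x_aug.T @ expected
    theta = torch.linalg.solve(lhs, rhs)  # (n_features + 1, n_targets)
    with torch.no_grad():
        # nn.Linear stores its weight as (n_targets, n_features).
        layer.weight.copy_(theta[:-1].T)
        layer.bias.copy_(theta[-1])


# Example: recover a known linear map from noiseless data.
torch.manual_seed(0)
x = torch.randn(200, 3, dtype=torch.float64)
true_w = torch.tensor([[1.0, -2.0, 0.5]], dtype=torch.float64)
y = x @ true_w.T + 3.0  # expected, shape (batch, n_targets) = (200, 1)
readout = nn.Linear(3, 1).double()
ridge_fit_linear(readout, x, y, l2_reg=1e-8)
```

With the documented class, the equivalent calls would be constructing RidgeRegression(readout.parameters(), l2_reg=...) and then calling fit(x, y). An nn.Linear's parameters() yields exactly the weight and the bias, which matches the two-parameter requirement above.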