RidgeClassification

class torch_rc.optim.RidgeClassification(params: Iterable, l2_reg: float = 1)

Implements a Ridge classification algorithm.

This classifier first converts the target values into {-1, 1} and then treats the problem as a regression task.
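The target conversion described above can be illustrated with a minimal sketch in plain PyTorch. It assumes a one-vs-rest encoding, where each class gets its own column of -1/+1 values; the page does not say how the library handles more than two classes, so this is an assumption, and `encode_targets` is a hypothetical helper rather than part of torch_rc.

```python
import torch


def encode_targets(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Map integer class labels to a (batch, n_classes) matrix of -1/+1 values.

    Hypothetical helper for illustration only; not part of torch_rc.
    """
    # one_hot gives 1 in the column of the true class and 0 elsewhere.
    one_hot = torch.nn.functional.one_hot(labels, num_classes=n_classes)
    # Rescale {0, 1} to {-1, +1} so the columns can serve as regression targets.
    return one_hot.to(torch.float32) * 2.0 - 1.0


labels = torch.tensor([0, 2, 1])
targets = encode_targets(labels, n_classes=3)

# A regression model fitted on `targets` yields one score per class;
# the predicted class is the column with the largest score.
decoded = targets.argmax(dim=1)
```

Decoding with `argmax` recovers the original labels from the encoded matrix, which is also how scores from the fitted regression would typically be turned back into class predictions.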

Parameters
  • params – iterable of parameters to optimize. This algorithm can only train linear layers, so the iterable must contain exactly 2 parameters (one for the weights and one for the biases).

  • l2_reg – L2 regularization strength. Must be a positive floating-point number.

fit(input, expected)

Solves the ridge optimization problem for the given input and expected tensors and overwrites the parameters with the solution.

Parameters
  • input – input tensor of shape (batch, n_features)

  • expected – target tensor of shape (batch,) and of type torch.long
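As a rough illustration of what a closed-form ridge fit on a linear layer can look like, the sketch below solves the regularized least-squares problem and copies the result into a `torch.nn.Linear`. It is not the torch_rc implementation: `ridge_fit_linear` is a hypothetical function, the one-vs-rest target encoding is an assumption, and so is the choice to leave the bias unregularized by centering the data first.

```python
import torch


def ridge_fit_linear(
    layer: torch.nn.Linear,
    x: torch.Tensor,
    y: torch.Tensor,
    l2_reg: float = 1.0,
) -> None:
    """Closed-form ridge fit written into `layer` (hypothetical sketch).

    x: (batch, n_features) float tensor
    y: (batch,) torch.long tensor of class indices
    """
    n_classes = layer.out_features
    # Encode class indices as -1/+1 regression targets (assumed one-vs-rest).
    targets = torch.nn.functional.one_hot(y, num_classes=n_classes).to(x.dtype) * 2 - 1

    # Centre inputs and targets so the bias is fitted without regularization
    # (an assumption; the library may treat the bias differently).
    x_mean = x.mean(dim=0, keepdim=True)
    t_mean = targets.mean(dim=0, keepdim=True)
    xc = x - x_mean
    tc = targets - t_mean

    # Solve (Xc^T Xc + l2_reg * I) W = Xc^T Tc for W of shape (n_features, n_classes).
    gram = xc.T @ xc + l2_reg * torch.eye(x.shape[1], dtype=x.dtype)
    w = torch.linalg.solve(gram, xc.T @ tc)
    b = t_mean - x_mean @ w  # shape (1, n_classes)

    # nn.Linear stores its weight as (out_features, in_features), hence the transpose.
    with torch.no_grad():
        layer.weight.copy_(w.T)
        layer.bias.copy_(b.squeeze(0))


# Tiny, linearly separable example with two classes.
x = torch.tensor([[-2.0, 0.0], [-1.0, 0.5], [1.0, -0.5], [2.0, 0.0]])
y = torch.tensor([0, 0, 1, 1])
layer = torch.nn.Linear(2, 2)
ridge_fit_linear(layer, x, y, l2_reg=0.1)
pred = layer(x).argmax(dim=1)
```

A single call replaces the layer's weights and biases outright; no gradient steps are involved, which is why such a fit only applies to a linear layer with exactly two parameters.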