Understanding the Link Between Loss Functions and Optimization Algorithms in PyTorch


When learning deep learning, I understand the flow as: make predictions with a model → calculate the loss with a loss function → optimize the loss (backpropagation and parameter update). In PyTorch, is the information "optimize this loss function with this optimization algorithm" connected automatically?

Specifically,

# Correct code
# Cross-entropy loss function
loss_fnc = nn.CrossEntropyLoss()

# SGD
optimizer = optim.SGD(net.parameters(), lr=0.01)  # learning rate is 0.01

rather than something like this:

# Incorrect code
# Cross-entropy loss function
loss_fnc = nn.CrossEntropyLoss()

# SGD
# Suppose SGD had a parameter specifying which function to optimize (hypothetical)
optimizer = optim.SGD(net.parameters(), lr=0.01, target=loss_fnc)  # learning rate is 0.01

Are the two connected automatically, without having to do something like this?

And if so, where is that connection made?

python deep-learning pytorch

2022-09-29 22:28

1 Answer

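The optimization algorithm and the loss function are independent of each other, so there is no explicit connection between them. They interact only indirectly, through the gradients stored on the parameters.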

  • Calling .backward() on the loss (the tensor returned by the loss function) computes the gradient of the loss with respect to each parameter. The gradient values accumulate in the parameters themselves, in their .grad attributes.
  • The optimization algorithm updates each parameter's value using the gradient stored on that parameter. Because it only looks at these gradients, it does not need to know which loss function produced them.
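To see that the optimizer needs nothing but the gradients, here is a minimal sketch in which no loss function exists at all. The tensor w and the gradient values are arbitrary, chosen only for illustration: a gradient is written into .grad by hand, and optimizer.step() still performs the plain SGD update w ← w − lr · w.grad.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# A lone parameter tensor (illustrative): no model and no loss function anywhere.
w = nn.Parameter(torch.tensor([1.0, -2.0, 3.0]))
optimizer = optim.SGD([w], lr=0.1)  # plain SGD: no momentum, no weight decay

# Write a gradient into w.grad by hand instead of calling loss.backward().
fake_grad = torch.tensor([0.5, 0.5, -1.0])
w.grad = fake_grad.clone()

before = w.detach().clone()
optimizer.step()  # w <- w - lr * w.grad

expected = before - 0.1 * fake_grad
print(torch.allclose(w.detach(), expected))  # True
```

The optimizer only ever received [w]; it has no way to know whether w.grad came from cross-entropy, some other loss, or a hand-written value.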
# Compute predictions and loss
pred = net(x)
loss = loss_fn(pred, y)

# Backpropagation and parameter update
optimizer.zero_grad()  # reset the gradients of all parameters registered with the optimizer
loss.backward()        # compute gradients; they are stored in each parameter's .grad attribute
optimizer.step()       # update the parameter values using their .grad values
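The fragment above assumes net, x, y, loss_fn and optimizer are defined elsewhere. A self-contained sketch of the same loop, assuming a toy linear classifier and a fixed random batch (both purely illustrative), could look like this:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy setup (illustrative only): a linear classifier on a fixed random batch.
net = nn.Linear(4, 3)
x = torch.randn(16, 4)
y = torch.randint(0, 3, (16,))

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1)  # no loss function is passed here

losses = []
for step in range(50):
    pred = net(x)              # forward pass
    loss = loss_fn(pred, y)    # scalar loss tensor
    optimizer.zero_grad()      # clear gradients left over from the previous step
    loss.backward()            # fill p.grad for every parameter p of net
    optimizer.step()           # p <- p - lr * p.grad
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

Replacing nn.CrossEntropyLoss with another loss function that accepts the same inputs would only change the loss_fn line; the optimizer construction stays exactly the same.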

The optimization algorithm gets the list of parameters to update from the net.parameters() argument passed to its constructor. For example, if you pass only some of the model's parameters there, only those parameters are updated.
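A sketch of passing only part of a model to the optimizer, assuming a small two-layer network built with nn.Sequential (illustrative only). loss.backward() still computes gradients for every layer, but optimizer.step() changes only the parameters the optimizer was given:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Illustrative two-layer network: net[0] and net[2] are the Linear layers.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

# Give the optimizer only the last layer's parameters.
optimizer = optim.SGD(net[2].parameters(), lr=0.1)

first_before = net[0].weight.detach().clone()
last_before = net[2].weight.detach().clone()

x = torch.randn(16, 4)
y = torch.randint(0, 3, (16,))
loss = nn.CrossEntropyLoss()(net(x), y)

optimizer.zero_grad()
loss.backward()   # gradients are computed for every layer...
optimizer.step()  # ...but only the parameters given to the optimizer are changed

got_grad = net[0].weight.grad is not None
first_unchanged = torch.equal(net[0].weight.detach(), first_before)
last_changed = not torch.equal(net[2].weight.detach(), last_before)
print(got_grad, first_unchanged, last_changed)
```

When part of a model is meant to stay frozen, it is also common to set requires_grad=False on those parameters so that no gradients are computed for them at all.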

By the way, optimizer.zero_grad() resets the gradients of all the parameters in net.parameters() (the ones passed to the optimizer). loss.backward() adds the new gradient to whatever is already stored (previous gradient + new gradient), so the previous gradient information has to be cleared before calculating the new gradient.
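The accumulation behaviour can be checked with a single scalar parameter (w and the function 3 · w are arbitrary choices for illustration). The gradient of 3 · w with respect to w is 3, so two backward passes without a reset leave 6 in w.grad:

```python
import torch
import torch.optim as optim

# A single scalar parameter; the gradient of 3 * w with respect to w is 3.
w = torch.nn.Parameter(torch.tensor(2.0))
optimizer = optim.SGD([w], lr=0.1)

(3 * w).backward()
first = w.grad.item()
print(first)   # 3.0

(3 * w).backward()   # no zero_grad() in between: 3 is added to the stored 3
second = w.grad.item()
print(second)  # 6.0

optimizer.zero_grad()
# PyTorch 2.0+ sets .grad to None by default (set_to_none=True);
# older versions filled it with zeros. Either way, the old gradient is gone.
print(w.grad)

(3 * w).backward()
third = w.grad.item()
print(third)   # 3.0: a fresh gradient, not 9.0
```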


2022-09-29 22:28



© 2024 OneMinuteCode. All rights reserved.