I am creating a deep learning image classification model in PyTorch.


When I use torch.sqrt() as part of the model's loss function, NaN values appear during backpropagation.
The cause seems to be that some elements of the tensor passed to torch.sqrt() are very small.

When the input to torch.sqrt() is very small, its gradient 1/(2*torch.sqrt(x)) appears to become inf during the backward pass...
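A quick way to check this suspicion is to take the gradient of torch.sqrt() directly at an exact zero, at a tiny positive value, and at an ordinary value. This is a minimal sketch; the input values are illustrative and not taken from the model.

```python
import torch

# An exact zero, a tiny positive value, and an ordinary value (float32).
x = torch.tensor([0.0, 1e-20, 4.0], requires_grad=True)

y = torch.sqrt(x).sum()
y.backward()

# The derivative of sqrt(x) is 1 / (2 * sqrt(x)):
#   x = 0     -> inf
#   x = 1e-20 -> roughly 5e9 (finite, but huge)
#   x = 4     -> 0.25
print(x.grad)
```

So an exact zero already gives inf, and values that are merely tiny give enormous gradients that can overflow or destabilize training further downstream.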

If anyone knows how to deal with it, please let me know.

Error Message

Traceback (most recent call last):
  File "main_label_grad.py", line 504, in <module>
    model_g=main()
  File "main_label_grad.py", line 459, in main
    tr_acc, tr_acc5, tr_los, grad_train, last_v4 = train(train_loader, net_c, net_t, optimizer_c, optimizer_t, epoch, args, log_G, args.noise_dim, grad_train_old=None)=vone
  File "main_label_grad.py", line 320, in train
    loss_trans.backward()
  File "C:\Users\GUESTUSER\.conda\envs\tf37\lib\site-packages\torch\tensor.py", line 118, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "C:\Users\GUESTUSER\.conda\envs\tf37\lib\site-packages\torch\autograd\__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'SqrtBackward' returned nan values in it's 0th output.

The loss_trans above is the model's objective function and corresponds to the first return value of the following function.
In this function (new_norm), NaN appears when v4_ema, the argument of torch.sqrt(v4_ema) in the return statement, is very small.

import torch

def new_norm(v, epoch, iter, last_v4=None):
    v2 = torch.pow(v, 2)
    v4 = torch.pow(v, 4)
    v4_ema = ema(v4, epoch, iter, last_v4)  # ema() is defined elsewhere (not shown)
    epsilon = torch.ones(v4_ema.size(0)) * 1e-10
    epsilon = epsilon.cuda()
    return (v2 / (torch.sqrt(v4_ema) + epsilon)).sum() / v4_ema.size(0), v4_ema
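The failure can likely be reproduced without the full training script. The sketch below runs new_norm on a tensor containing an exact zero (what very small values of v turn into once v**4 underflows in float32). Assumptions: `ema` is a hypothetical stand-in that returns its input unchanged, because the real `ema` is not shown, and epsilon stays on the CPU instead of calling `.cuda()`. With v = 0, v2 is 0, so the gradient arriving at the square root is 0 while sqrt(v4_ema) is also 0; SqrtBackward then effectively computes 0 / 0 and returns NaN. The "returned nan values" message in the traceback is the kind of error autograd raises when anomaly detection is enabled, which the second half of the sketch shows.

```python
import torch


def ema(v4, epoch, iter, last_v4=None):
    # Hypothetical stand-in: the real ema() is not shown in the question,
    # so this simply returns its input unchanged.
    return v4


def new_norm(v, epoch, iter, last_v4=None):
    v2 = torch.pow(v, 2)
    v4 = torch.pow(v, 4)
    v4_ema = ema(v4, epoch, iter, last_v4)
    # CPU tensor instead of .cuda(), so the sketch runs without a GPU.
    epsilon = torch.ones(v4_ema.size(0)) * 1e-10
    return (v2 / (torch.sqrt(v4_ema) + epsilon)).sum() / v4_ema.size(0), v4_ema


# Plain backward: the gradient for the zero element becomes nan.
v = torch.tensor([0.0, 1.0], requires_grad=True)
loss, _ = new_norm(v, epoch=0, iter=0)
loss.backward()
plain_grad = v.grad.clone()
print(plain_grad)

# With anomaly detection, autograd raises a RuntimeError naming the
# backward function (SqrtBackward, possibly with a version suffix).
v = torch.tensor([0.0, 1.0], requires_grad=True)
message = ""
with torch.autograd.detect_anomaly():
    loss, _ = new_norm(v, epoch=0, iter=0)
    try:
        loss.backward()
    except RuntimeError as err:
        message = str(err)
print(message)
```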

python pytorch

2022-09-29 20:30

1 Answer

This is not specific to PyTorch: the derivative of sqrt(x) is 1/(2*sqrt(x)), so when x is 0 the derivative becomes inf, and that inf turns into NaN in the subsequent calculations (for example, 0 * inf is NaN).

If you pad the input so that it never gets close to zero, for example with torch.sqrt(torch.clamp(x, min=1.0e-6)), I think the error will go away.
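A minimal sketch of the clamp approach, using the same illustrative inputs as a plain sqrt check (0, 1e-20, 4.0) and the 1.0e-6 threshold suggested above:

```python
import torch

x = torch.tensor([0.0, 1e-20, 4.0], requires_grad=True)

# Clamp before the square root so sqrt never sees a value below 1e-6.
y = torch.sqrt(torch.clamp(x, min=1.0e-6)).sum()
y.backward()

# Elements below the threshold receive zero gradient from clamp,
# so no inf or nan reaches x.grad; 4.0 still gets 1 / (2 * 2) = 0.25.
print(x.grad)
```

Because clamp passes no gradient to elements below `min`, those elements simply stop contributing gradient instead of producing inf. If a nonzero gradient everywhere is preferred, torch.sqrt(x + eps) is a common alternative. In new_norm, note that epsilon is added after torch.sqrt(), so it keeps the forward division finite but does not protect SqrtBackward; clamping v4_ema (or moving the epsilon inside the square root) is what addresses the NaN.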


2022-09-29 20:30



© 2024 OneMinuteCode. All rights reserved.