当前位置:网站首页>RuntimeError: Providing a bool or integral fill value without setting the optional `dtype` or `out`

RuntimeError: Providing a bool or integral fill value without setting the optional `dtype` or `out`

2022-04-23 19:13:00 pflik-sj

一、问题描述

报错提示:

RuntimeError: Providing a bool or integral fill value without setting the optional `dtype` or `out` arguments is currently unsupported. In PyTorch 1.7, when `dtype` and `out` are not set a bool fill value will return a tensor of torch.bool dtype, and an integral fill value will return a tensor of torch.long dtype.

报错全文提示:

 File "main.py", line 62, in <module>
    run(train_dataset, val_dataset, test_dataset, args.save_dir, args.log_dir, model, args.epochs, args.batch_size, args.lr, args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay, args.energy_and_force, args.p)
  File "**项目路径**/SphereNet_3D/spherenet_code/train.py", line 42, in run
    train_loss = train(model, optimizer, train_loader, energy_and_force, p, loss_func, device)
  File "**项目路径**/SphereNet_3D/spherenet_code/train.py", line 75, in train
    for step, batch_data in enumerate(tqdm(train_loader)):
  File "**虚拟环境路径**/miniconda3/envs/spherenet/lib/python3.7/site-packages/tqdm/std.py", line 1195, in __iter__
    for obj in iterable:
  File "**虚拟环境路径**/miniconda3/envs/spherenet/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 363, in __next__
    data = self._next_data()
  File "**虚拟环境路径**/miniconda3/envs/spherenet/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 403, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "**虚拟环境路径**/miniconda3/envs/spherenet/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "**虚拟环境路径**/miniconda3/envs/spherenet/lib/python3.7/site-packages/torch_geometric/loader/dataloader.py", line 20, in __call__
    self.exclude_keys)
  File "**虚拟环境路径**/miniconda3/envs/spherenet/lib/python3.7/site-packages/torch_geometric/data/batch.py", line 74, in from_data_list
    exclude_keys=exclude_keys,
  File "**虚拟环境路径**/miniconda3/envs/spherenet/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 109, in collate
    out_store.batch = repeat_interleave(repeats, device=device)
  File "**虚拟环境路径**/miniconda3/envs/spherenet/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 208, in repeat_interleave
    outs = [torch.full((n, ), i, device=device) for i, n in enumerate(repeats)]
  File "**虚拟环境路径**/miniconda3/envs/spherenet/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 208, in <listcomp>
    outs = [torch.full((n, ), i, device=device) for i, n in enumerate(repeats)]
RuntimeError: Providing a bool or integral fill value without setting the optional `dtype` or `out` arguments is currently unsupported. In PyTorch 1.7, when `dtype` and `out` are not set a  bool fill value will return a tensor of torch.bool dtype, and an integral fill value will return a tensor of torch.long dtype.

二、情况分析

  1. 报错的主要原因是:pytorch版本有点低,出现了一些问题。
    我使用的pytorch == 1.6.0,也是属于官方提示的范围内的。一般出现问题去官网看一下,再根据报错的情况具体查一下。

  2. 此次报错情况(一般是最后一行代码出现的问题):

    outs = [torch.full((n, ), i, device=device) for i, n in enumerate(repeats)]

    说明是:full()这个函数有问题。
    并且根据提示:RuntimeError: Providing a bool or integral fill value without setting the optional dtypeorout arguments is currently unsupported.
    基本上就是因为dtype或者 out 这两个参数的问题。

  3. 我们看一下官网(按着自己的版本):

    注意他说的这个警告!当前不支持fill_value在未设置可选dtypeout参数的情况下提供布尔值或整数填充值。 其实就是需要把参数置成dtype=torch.longdtype=torch.bool,需要给参数一个值,不然没有想要的返回值。
    之前还试了一下dtype=torch.float,可是这个并不能解决问题,而且还有了新的问题。参照官方文档,个人感觉这样是不对的。

三、解决方法

方法一

  1. 根据报错的最后一项:
  File "**虚拟环境路径**/miniconda3/envs/spherenet/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 208, in <listcomp>
    outs = [torch.full((n, ), i, device=device) for i, n in enumerate(repeats)]
  1. 根据提示的路径找到这个文件collate.py,找到torch.full()函数。
    注:在linux环境下直接使用/full这个命令就可以了。
  2. 最后直接加上这个参数dtype=torch.long就可以了。在这里插入图片描述
    我的问题出现的情况没有在原文中,在我的代码中没有找到torch.full(),所以修改了源码

方法二

在之前解决方法之前肯定查了好多的方法,比较靠谱的就是下面这个。
参考:参考博文
只需要提供tensor返回的数据类型(dtype=torch.long)就可以了,而不用再在后面加.torch.long()这种方式了。具体加的位置是在torch.full()函数的参数项中即可。

#self.register_buffer('words_to_words ', torch.full((len(vocab),),fill_value=unknown_idx) .long())
#但是因为好像版本问题报错,所以就修改成了下面这样子
self.register_buffer('words_to words ',torch.full(len(vocab),),fill_value=unknown_idx,dtype=torch.long))

版权声明
本文为[pflik-sj]所创,转载请带上原文链接,感谢
https://sunflower.blog.csdn.net/article/details/124360181