Tez is a super-simple and lightweight trainer for PyTorch. It also comes with many utilities that you can use to tackle over 90% of deep learning projects in PyTorch.

Overview

Tez: a simple pytorch trainer

NOTE: Currently, we are not accepting any pull requests! All PRs will be closed. If you want a feature or something doesn't work, please create an issue.

tez (तेज़ / تیز) means sharp, fast & active. This is a simple, to-the-point library to make your pytorch training easy.

This library is currently at a very early stage, so there might be breaking changes.

The idea behind tez is simple:

  • keep things as simple as possible
  • make it as customizable as possible
  • clean code
  • faster prototyping
  • production ready

Currently, tez supports CPU and GPU training. More coming soon!

Using tez is super-easy. We don't want to take you far away from pytorch: you still do everything on your own and use tez only to make a few things simpler.

Training using Tez:

  • To train a model, define a dataset and a model. The dataset class is the same class you would write for any pytorch model.

  • Create your model class. Instead of inheriting from nn.Module, import tez and inherit from tez.Model as shown in the following example.

import tez
import torch
import torch.nn as nn
import transformers
from sklearn import metrics


class MyModel(tez.Model):
    def __init__(self):
        super().__init__()
        # define your layers here, e.g. (the original elides them):
        # self.bert = transformers.BertModel.from_pretrained("bert-base-uncased")
        # self.bert_drop = nn.Dropout(0.3)
        # self.out = nn.Linear(768, 1)

        # tell tez when to step the scheduler: after every "batch" or "epoch"
        self.step_scheduler_after = "batch"

    def monitor_metrics(self, outputs, targets):
        if targets is None:
            return {}
        outputs = torch.sigmoid(outputs).cpu().detach().numpy() >= 0.5
        targets = targets.cpu().detach().numpy()
        accuracy = metrics.accuracy_score(targets, outputs)
        return {"accuracy": accuracy}

    def fetch_optimizer(self):
        # create your own optimizer, e.g.:
        return torch.optim.Adam(self.parameters(), lr=3e-4)

    def fetch_scheduler(self):
        # create your own scheduler; self.optimizer is assumed to be
        # set by tez before this hook is called, e.g.:
        return torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.9)

    def forward(self, ids, mask, token_type_ids, targets=None):
        # note: with transformers >= 4, pass return_dict=False (or use
        # .pooler_output) so the output unpacks as a tuple
        _, o_2 = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        b_o = self.bert_drop(o_2)
        output = self.out(b_o)

        if targets is not None:
            # calculate loss here
            loss = nn.BCEWithLogitsLoss()(output, targets)
            # calculate the metric dictionary here
            metric_dict = self.monitor_metrics(output, targets)
            return output, loss, metric_dict
        return output, None, {}

Everything is super-intuitive!
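
For context, a dataset here is just a plain torch.utils.data.Dataset. Below is a minimal, hedged sketch: tez appears to pass each batch dict into forward as keyword arguments, so the dict keys should match forward's parameter names.

import torch

class SomeTrainDataset(torch.utils.data.Dataset):
    # minimal hedged sketch: return a dict whose keys match the
    # parameters of MyModel.forward (ids, mask, token_type_ids, targets)
    def __init__(self, encodings, targets):
        self.encodings = encodings
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, item):
        return {
            "ids": torch.tensor(self.encodings["input_ids"][item], dtype=torch.long),
            "mask": torch.tensor(self.encodings["attention_mask"][item], dtype=torch.long),
            "token_type_ids": torch.tensor(self.encodings["token_type_ids"][item], dtype=torch.long),
            "targets": torch.tensor(self.targets[item], dtype=torch.float),
        }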

  • Now you can train your model!
# init datasets
train_dataset = SomeTrainDataset()
valid_dataset = SomeValidDataset()

# init model
model = MyModel()


# init callbacks, you can also write your own callback
es = tez.callbacks.EarlyStopping(monitor="valid_loss", model_path="model.bin")

# train model. a familiar api!
model.fit(
    train_dataset,
    valid_dataset=valid_dataset,
    train_bs=32,
    device="cuda",
    epochs=50,
    callbacks=[es],
    fp16=True,
)

# save the model (along with the optimizer and scheduler, for later use!)
model.save("model.bin")

You can check out more examples in the examples/ directory.

Comments
  • ValueError: operands could not be broadcast together with shapes (256,256,4) (3,) (256,256,4)

    I am trying to use this package and it throws the error below. I am using the same pipeline as in the cassava leaf detection problem, but on a different dataset where the image size is (256, 256).

    Could you please help here?

    Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b4-6ed6700e.pth 100% 74.4M/74.4M [00:00<00:00, 107MB/s]

    Loaded pretrained weights for efficientnet-b4 0%| | 0/51 [00:00<?, ?it/s]

    ValueError Traceback (most recent call last) in () 11 epochs=10, 12 callbacks=[es], ---> 13 fp16=True, 14 ) 15 model.save("model.bin")

    6 frames /usr/local/lib/python3.6/dist-packages/tez/model/model.py in fit(self, train_dataset, valid_dataset, train_sampler, valid_sampler, device, epochs, train_bs, valid_bs, n_jobs, callbacks, fp16) 295 self.train_state = enums.TrainingState.EPOCH_START 296 self.train_state = enums.TrainingState.TRAIN_EPOCH_START --> 297 train_loss = self.train_one_epoch(self.train_loader, device) 298 self.train_state = enums.TrainingState.TRAIN_EPOCH_END 299 if self.valid_loader:

    /usr/local/lib/python3.6/dist-packages/tez/model/model.py in train_one_epoch(self, data_loader, device) 176 losses = AverageMeter() 177 tk0 = tqdm(data_loader, total=len(data_loader)) --> 178 for b_idx, data in enumerate(tk0): 179 self.train_state = enums.TrainingState.TRAIN_STEP_START 180 loss, metrics = self.train_one_step(data, device)

    /usr/local/lib/python3.6/dist-packages/tqdm/std.py in iter(self) 1102 fp_write=getattr(self.fp, 'write', sys.stderr.write)) 1103 -> 1104 for obj in iterable: 1105 yield obj 1106 # Update and possibly print the progressbar.

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in next(self) 433 if self._sampler_iter is None: 434 self._reset() --> 435 data = self._next_data() 436 self._num_yielded += 1 437 if self._dataset_kind == _DatasetKind.Iterable and \

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self) 1083 else: 1084 del self._task_info[idx] -> 1085 return self._process_data(data) 1086 1087 def _try_put_index(self):

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data) 1109 self._try_put_index() 1110 if isinstance(data, ExceptionWrapper): -> 1111 data.reraise() 1112 return data 1113

    /usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self) 426 # have message field 427 raise self.exc_type(message=msg) --> 428 raise self.exc_type(msg) 429 430

    ValueError: Caught ValueError in DataLoader worker process 0. Original Traceback (most recent call last): File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop data = fetcher.fetch(index) File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch data = [self.dataset[idx] for idx in possibly_batched_index] File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in data = [self.dataset[idx] for idx in possibly_batched_index] File "/usr/local/lib/python3.6/dist-packages/tez/datasets/image_classification.py", line 48, in getitem augmented = self.augmentations(image=image) File "/usr/local/lib/python3.6/dist-packages/albumentations/core/composition.py", line 171, in call data = t(**data) File "/usr/local/lib/python3.6/dist-packages/albumentations/core/transforms_interface.py", line 38, in call res[key] = target_function(arg, **dict(params, **target_dependencies)) File "/usr/local/lib/python3.6/dist-packages/albumentations/augmentations/transforms.py", line 808, in apply return F.normalize(image, self.mean, self.std, self.max_pixel_value) File "/usr/local/lib/python3.6/dist-packages/albumentations/augmentations/functional.py", line 93, in normalize img -= mean ValueError: operands could not be broadcast together with shapes (256,256,4) (3,) (256,256,4)
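
    Not an official answer, but the shapes in the final error suggest the new dataset contains 4-channel (RGBA) images, while the Normalize transform's mean/std have 3 values. A hedged sketch of a fix is to drop the alpha channel before augmentation:

    import numpy as np

    # hedged helper: drop the alpha channel so images match the
    # 3-value mean/std expected by albumentations' Normalize
    def to_rgb(image: np.ndarray) -> np.ndarray:
        if image.ndim == 3 and image.shape[2] == 4:
            return image[:, :, :3]  # keep R, G, B; discard alpha
        return image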

    opened by nvnvashisth 10
  • zero_grad for accumulation_steps = 1 not working as expected

    As far as I know, in the normal execution flow we first call zero_grad for each batch and then do the forward pass. But when I investigated the code, that is not what happens when accumulation_steps = 1 and batch = 1: the first forward pass executes before any zero_grad.

    I tried to reproduce it, and it behaves exactly as described above.

    Also, I think we can fix this by removing the condition in tez.py at lines 330-331. For reference, the conventional ordering is sketched below.
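
    A minimal runnable sketch (with assumed toy names, not tez's actual code) of the conventional ordering, where zero_grad happens before the first forward pass:

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(6)]
    accumulation_steps = 1  # with 1, this reduces to zero_grad -> forward -> backward -> step

    optimizer.zero_grad()  # clear gradients before the first forward pass
    for step, (x, y) in enumerate(data):
        loss = torch.nn.functional.mse_loss(model(x), y) / accumulation_steps
        loss.backward()  # accumulate gradients
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()       # update weights
            optimizer.zero_grad()  # reset before the next accumulation window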

    opened by abdurrehman11 9
  • Can it work without CUDA?

    I get an error when I execute the code on a CPU-only configuration.

    Traceback (most recent call last):
      File "c:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\recommender.py", line 88, in <module>
        train()
      File "c:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\recommender.py", line 82, in train
        model.fit(
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\tez\model\model.py", line 309, in fit
        self._init_model(
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\tez\model\model.py", line 93, in _init_model
        self.to(self.device)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 852, in to
        return self._apply(convert)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 530, in _apply
        module._apply(fn)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 552, in _apply
        param_applied = fn(param)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 850, in convert
        return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\cuda\__init__.py", line 166, in _lazy_init
        raise AssertionError("Torch not compiled with CUDA enabled")
    AssertionError: Torch not compiled with CUDA enabled
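
    Since the README states that tez supports CPU training, a hedged sketch of a fix is to pass device="cpu" to fit (and disable fp16, which needs CUDA):

    # hedged sketch: run on CPU instead of the default "cuda"
    model.fit(
        train_dataset,
        valid_dataset=valid_dataset,
        train_bs=32,
        device="cpu",   # avoids the CUDA assertion on CPU-only installs
        epochs=50,
        callbacks=[es],
        fp16=False,     # mixed precision requires CUDA
    )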

    opened by hemanthh17 7
  • Documentation improvement - How is tez faster?

    Great to see a nice PyTorch training library.

    I think it would help users to show what kind of performance improvements come out of the box with Tez. For example, comparing how fp16 is enabled in tez vs vanilla pytorch could be informative, or just a quick list of optimisations that are easy to do with Tez, such as fp16. A comparison along those lines is sketched below.
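
    For illustration, a hedged sketch of that comparison: vanilla pytorch mixed precision needs an explicit GradScaler and autocast in the training loop, while tez wraps it behind a flag (the toy model and data below are assumptions):

    import torch

    model = torch.nn.Linear(10, 1).cuda()
    optimizer = torch.optim.Adam(model.parameters())
    scaler = torch.cuda.amp.GradScaler()

    for _ in range(100):
        x = torch.randn(32, 10, device="cuda")
        y = torch.randn(32, 1, device="cuda")
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():      # forward in mixed precision
            loss = torch.nn.functional.mse_loss(model(x), y)
        scaler.scale(loss).backward()        # scale loss to avoid fp16 underflow
        scaler.step(optimizer)
        scaler.update()

    # in tez, the same thing is a single flag on fit (per the README example):
    # model.fit(train_dataset, train_bs=32, device="cuda", fp16=True)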

    opened by swartchris8 5
  • Is it possible to set a variable LR per epoch?

    @abhishekkrthakur I find this framework great and easy to use. But being fairly new to it, I was wondering whether there is a way to use a variable LR during training, for example a different LR for every epoch (one possible approach is sketched below).

    Also, is there a way to continue training from a particular epoch if, say, the local system crashed or was interrupted during training?
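
    A hedged sketch of one way to vary the LR per epoch with the hooks this README shows: return a scheduler from fetch_scheduler and set step_scheduler_after = "epoch". The halving schedule below is a hypothetical example, and tez is assumed to set self.optimizer before calling fetch_scheduler:

    import torch
    import tez

    class MyModel(tez.Model):
        def __init__(self):
            super().__init__()
            # step the scheduler once per epoch instead of per batch
            self.step_scheduler_after = "epoch"
            # ... layers elided ...

        def fetch_optimizer(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

        def fetch_scheduler(self):
            # hypothetical schedule: halve the LR after every epoch
            return torch.optim.lr_scheduler.MultiplicativeLR(
                self.optimizer, lr_lambda=lambda epoch: 0.5
            )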

    opened by gauravbrills 3
  • Applying metrics after the epoch

    I am using tez to classify melanoma images (the Kaggle SIIM binary classification). With wtfml it is possible to get AUC ~0.85; with tez I only get AUC ~0.6. I found that this happens in tez when using metrics.roc_auc_score(...) inside the monitor_metrics method: it raises ValueError exceptions that must be handled by returning auc = 0.5 (the error occurs when a batch contains only one class).

    In wtfml, metrics.roc_auc_score(...) is applied only after Engine.evaluate. In that case the data always contains both classes (because stratified k-fold guarantees that).

    I am wondering whether it is possible, in tez, to apply metrics.roc_auc_score(...) only after the epoch instead of on every batch. That way the data will always contain both classes, avoiding the ValueError exceptions (see the sketch after this list).

    PS.

    1. In the class __init__ I am using: self.step_scheduler_after = "epoch" and self.step_scheduler_metric = "valid_auc"
    2. In the monitor_metrics method:
       try:
           auc = metrics.roc_auc_score(targets, outputs.ravel())
       except ValueError:
           auc = 0.5
       return {"auc": auc}
    3. My model.fit is defined as: model.fit(train_dataset, valid_dataset=valid_dataset, train_bs=32, valid_bs=16, device="cuda", epochs=50, callbacks=[es], fp16=False, n_jobs=2)
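
    A hedged sketch of one way to do this with plain PyTorch: run the trained model over the whole validation set and compute the AUC once, so both classes are always present (names follow the README example; adapt to your pipeline):

    import numpy as np
    import torch
    from sklearn import metrics

    device = "cuda" if torch.cuda.is_available() else "cpu"
    loader = torch.utils.data.DataLoader(valid_dataset, batch_size=16)

    model.eval()
    all_outputs, all_targets = [], []
    with torch.no_grad():
        for data in loader:
            data = {k: v.to(device) for k, v in data.items()}
            targets = data.pop("targets")
            output, _, _ = model(**data)  # forward without targets -> raw logits
            all_outputs.append(torch.sigmoid(output).cpu().numpy().ravel())
            all_targets.append(targets.cpu().numpy().ravel())

    # epoch-level AUC: the full validation set contains both classes
    auc = metrics.roc_auc_score(np.concatenate(all_targets), np.concatenate(all_outputs))
    print(f"valid_auc={auc:.4f}")
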
    opened by waldcarl 2
  • Issue while using the AUC metric on an imbalanced dataset like melanoma (ValueError: Only one class present in y_true. ROC AUC score is not defined in that case)

    This problem occurs because the metric is computed on each running batch.

    I got this explanation from Stack Overflow:

    You cannot have an ROC curve without both positive and negative examples in your dataset. With only one class in the dataset, you cannot measure your false-positive rate, and therefore cannot plot an ROC curve. This is why you get this error message.

    How can this problem be handled? (One hedged workaround is sketched below.)
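
    A hedged workaround, echoing the PS of the previous issue: guard the per-batch metric and fall back to a neutral value when a batch contains a single class:

    import torch
    from sklearn import metrics

    def monitor_metrics(self, outputs, targets):
        if targets is None:
            return {}
        targets = targets.cpu().detach().numpy().ravel()
        outputs = torch.sigmoid(outputs).cpu().detach().numpy().ravel()
        try:
            auc = metrics.roc_auc_score(targets, outputs)
        except ValueError:
            # single-class batch: ROC AUC is undefined, so report
            # chance level (0.5) as a neutral placeholder
            auc = 0.5
        return {"auc": auc}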

    opened by IamSantoshKumar 2
  • Error in Multiclass: TypeError: dropout(): argument 'input' (position 1) must be Tensor, not str

    /usr/local/lib/python3.7/dist-packages/torch/cuda/amp/grad_scaler.py:116: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.
      warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
    0%| | 0/2939 [00:00<?, ?it/s]
    /usr/local/lib/python3.7/dist-packages/torch/cuda/amp/autocast_mode.py:118: UserWarning: torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available. Disabling.
      warnings.warn("torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available. Disabling.")

    TypeError                                 Traceback (most recent call last)
    <ipython-input> in <module>()
        143     epochs=3,
        144     callbacks=[tb_logger, es],
    --> 145     fp16=True,
        146 )
        147 model.save("model.bin")

    8 frames
    /usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in dropout(input, p, training, inplace)
       1074     if p < 0.0 or p > 1.0:
       1075         raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
    -> 1076     return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
       1077
       1078

    TypeError: dropout(): argument 'input' (position 1) must be Tensor, not str
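
    A hedged guess at the cause: with transformers >= 4, models return a dict-like ModelOutput by default, so tuple-unpacking _, o_2 = self.bert(...) iterates over its string keys, and one of those strings eventually reaches dropout. If that is what is happening here, either request plain tuples or use the named field:

    # option 1: ask the model for plain tuples (the old behaviour)
    _, o_2 = self.bert(
        ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=False
    )

    # option 2: use the named field of the ModelOutput
    out = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
    o_2 = out.pooler_output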

    opened by gokulguptanew 1
  • Text classification examples - Tokenizer is defined twice

    The tokenizer is defined in both the model and the dataset in the BERT text classification examples.

    multi_class.py, line 50:

    self.tokenizer = transformers.BertTokenizer.from_pretrained(
        "bert-base-uncased", do_lower_case=True
    )

    opened by obesp 1
  • Small error in image_classification.py

    If augmentations is None we get an error: UnboundLocalError: local variable 'augmented' referenced before assignment.

    elif self.backend == "cv2":
                image = cv2.imread(self.image_paths[item])
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                if self.resize is not None:
                    image = cv2.resize(
                        image,
                        (self.resize[1], self.resize[0]),
                        interpolation=cv2.INTER_CUBIC,
                    )
                if self.augmentations is not None:
                    augmented = self.augmentations(image=image)
                    image = augmented["image"]
    

    If the indentation is fixed, so that image = augmented["image"] only runs inside the if self.augmentations is not None block (as shown above), the error is solved.

    opened by VpkPrasanna 1
  • Small error in model.py

    Hi! Love this library.
    In tez/model/model.py there is probably a mistake in line 90:

    self.train_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=train_bs,
                    num_workers=n_jobs,
                    sampler=valid_sampler,
                    shuffle=True,
                )
    

    I guess train_sampler is meant to be used here, not valid_sampler.

    opened by hocop 1
  • run example code error

    When I run the example code:

    accelerate launch   imdb_sentiment_classification.py
    

    after running for some epochs I get this error:

    INFO:tez.callbacks.early_stopping:EarlyStopping counter: 4/5
    [train] accuracy=0.9915, loss=0.0269 [valid] accuracy=0.8953, loss=0.4287 [e=5 steps=2112]                                                                                                 
     30%|████████████████████████████████▍                                                                           | 2112/7040 [05:45<06:40, 12.32it/s, accuracy=0.991, epoch=5, loss=0.0269]2022-09-17 07:55:02,832 INFO EarlyStopping counter: 5/5
    INFO:tez.callbacks.early_stopping:EarlyStopping counter: 5/5
     30%|████████████████████████████████▍                                                                           | 2112/7040 [05:47<13:31,  6.07it/s, accuracy=0.991, epoch=5, loss=0.0269]
    
    
    
    
    [E ProcessGroupNCCL.cpp:719] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808970 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:719] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808984 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:719] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1809275 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1809275 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808970 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808984 milliseconds before timing out.
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/resource_sharer.py", line 138, in _serve
        with self._listener.accept() as conn:
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 470, in accept
        deliver_challenge(c, self._authkey)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 745, in deliver_challenge
        response = connection.recv_bytes(256)        # reject large message
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 221, in recv_bytes
        buf = self._recv_bytes(maxlength)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 419, in _recv_bytes
        buf = self._recv(4)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 384, in _recv
        chunk = read(handle, remaining)
    ConnectionResetError: [Errno 104] Connection reset by peer
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 113654 closing signal SIGTERM
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/resource_sharer.py", line 138, in _serve
        with self._listener.accept() as conn:
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 470, in accept
        deliver_challenge(c, self._authkey)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 745, in deliver_challenge
        response = connection.recv_bytes(256)        # reject large message
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 221, in recv_bytes
        buf = self._recv_bytes(maxlength)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 419, in _recv_bytes
        buf = self._recv(4)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 384, in _recv
        chunk = read(handle, remaining)
    ConnectionResetError: [Errno 104] Connection reset by peer
    ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 2 (pid: 113655) of binary: /root/miniconda3/envs/lightning/bin/python
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/bin/torchrun", line 33, in <module>
        sys.exit(load_entry_point('torch==1.11.0', 'console_scripts', 'torchrun')())
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
        return f(*args, **kwargs)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/run.py", line 724, in main
        run(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/run.py", line 715, in run
        elastic_launch(
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
        return launch_agent(self._config, self._entrypoint, list(args))
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
        raise ChildFailedError(
    torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
    =======================================================
    imdb_sentiment_classification.py FAILED
    -------------------------------------------------------
    Failures:
    [1]:
      time      : 2022-09-17_08:25:22
      host      : dy-a100-779-tlzrv
      rank      : 3 (local_rank: 3)
      exitcode  : -6 (pid: 113656)
      error_file: <N/A>
      traceback : Signal 6 (SIGABRT) received by PID 113656
    -------------------------------------------------------
    Root Cause (first observed failure):
    [0]:
      time      : 2022-09-17_08:25:22
      host      : dy-a100-779-tlzrv
      rank      : 2 (local_rank: 2)
      exitcode  : -6 (pid: 113655)
      error_file: <N/A>
      traceback : Signal 6 (SIGABRT) received by PID 113655
    =======================================================
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/bin/accelerate", line 33, in <module>
        sys.exit(load_entry_point('accelerate==0.12.0.dev0', 'console_scripts', 'accelerate')())
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/accelerate_cli.py", line 43, in main
        args.func(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/launch.py", line 734, in launch_command
        multi_gpu_launcher(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/launch.py", line 374, in multi_gpu_launcher
        raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
    subprocess.CalledProcessError: Command '['torchrun', '--nproc_per_node', '4', 'imdb_sentiment_classification.py']' returned non-zero exit status 1.
    
    opened by bestpredicts 0
  • Getting error while importing enums from tez.

    Traceback (most recent call last):
      File "/content/tez/tez/model/model.py", line 12, in <module>
        from tez import enums
      File "/content/tez/tez/model/tez.py", line 11, in <module>
        from tez import enums
    ImportError: cannot import name 'enums' from 'tez' (/content/tez/tez/model/tez.py)

    Waiting for a positive reply.

    opened by VikasRathod314 3
  • Saving validation score

    Is it possible to somehow save a list of the validation scores (per epoch or per batch) after training? The output on my server usually gets deleted, but I really need the validation scores to compare models. It would be very convenient if I could get them in one file, for example (one possible callback is sketched below).
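
    One possible approach, sketched under the assumption that custom callbacks get the same on_epoch_end(model) hook EarlyStopping uses, and that the model exposes metrics and current_epoch attributes (names may differ across tez versions):

    import json
    import tez

    class MetricsLogger(tez.callbacks.Callback):
        def __init__(self, path="metrics_log.json"):
            self.path = path
            self.history = []

        def on_epoch_end(self, model):
            # model.metrics / model.current_epoch are assumed attributes
            self.history.append(
                {"epoch": model.current_epoch, "valid": dict(model.metrics.get("valid", {}))}
            )
            with open(self.path, "w") as f:
                json.dump(self.history, f, indent=2)  # rewrite the full log each epoch

    # model.fit(train_dataset, valid_dataset=valid_dataset, callbacks=[MetricsLogger(), es])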

    opened by 25icecreamflavors 0
  • Saving after training an epoch

    How can I save the model after each training epoch? I use the fit method for 5 epochs and do not really understand how to save after each one, not only after the last one (a possible callback is sketched below).
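
    A hedged sketch of a per-epoch checkpoint callback, under the same assumptions as the previous issue (an on_epoch_end(model) hook and a current_epoch attribute):

    import tez

    class CheckpointEveryEpoch(tez.callbacks.Callback):
        def on_epoch_end(self, model):
            # save a separate checkpoint after every epoch
            model.save(f"model_epoch_{model.current_epoch}.bin")

    # model.fit(train_dataset, epochs=5, callbacks=[CheckpointEveryEpoch()])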

    opened by 25icecreamflavors 2
  • How can we access the input_ids/attention mask in each train batch loop?

    I tried using a train step callback, but I am not sure how to get access to the dataloader's input_ids and attention mask during each train step. Is this possible? (One hedged workaround is sketched below.)
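
    One hedged workaround, grounded in the forward signature shown in this README: tez feeds each batch dict into forward, so the batch tensors can be captured there and read from a callback afterwards.

    # hedged workaround: stash the current batch on the model inside forward
    class InspectingModel(MyModel):  # MyModel as in the README example
        def forward(self, ids, mask, token_type_ids, targets=None):
            self.last_batch = {"ids": ids, "mask": mask}  # readable from callbacks
            return super().forward(ids, mask, token_type_ids, targets=targets)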

    BTW, thanks for the library!

    opened by tkmaker 0
Releases: v0.1.8

Owner: abhishek thakur (Kaggle: www.kaggle.com/abhishek)