A best-practice architecture for a TensorFlow project template.

Overview

Tensorflow Project Template

A simple and well-designed structure is essential for any deep learning project. After a lot of practice and contributions to TensorFlow projects, here is a TensorFlow project template that combines simplicity, best practices for folder structure, and good OOP design. The main idea is that a lot of the work is the same every time you start a TensorFlow project, so wrapping the shared parts lets you change only the core idea each time you start a new one.

In short, this is a simple TensorFlow template that helps you get into your main project faster and focus on your core (model, training, etc.).

In a Nutshell

Here is how to use this template. As an example, assume you want to implement a VGG model. You would do the following:

  • In the models folder, create a class named VGGModel that inherits from the "base_model" class (BaseModel).
    class VGGModel(BaseModel):
        def __init__(self, config):
            super(VGGModel, self).__init__(config)
            # call the build_model and init_saver functions
            self.build_model()
            self.init_saver()
  • Override two functions: "build_model", where you implement the VGG model, and "init_saver", where you define a TensorFlow saver. Then call them in the initializer, as shown above.
    def build_model(self):
        # build the TensorFlow graph of the model here and define the loss
        pass

    def init_saver(self):
        # initialize the TensorFlow saver that will be used to save checkpoints
        self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)
  • In the trainers folder, create a VGG trainer that inherits from the "base_train" class (BaseTrain).
    class VGGTrainer(BaseTrain):
        def __init__(self, sess, model, data, config, logger):
            super(VGGTrainer, self).__init__(sess, model, data, config, logger)
  • Override two functions, "train_step" and "train_epoch", where you write the logic of the training process.
    def train_epoch(self):
        """
        Implement the logic of an epoch:
        - loop over the number of iterations in the config and call the train step
        - add any summaries you want using the summary
        """
        pass

    def train_step(self):
        """
        Implement the logic of the train step:
        - run the TensorFlow session
        - return any metrics you need to summarize
        """
        pass
  • In the main file, create the session and instances of the following objects: "Model", "Logger", "Data_Generator", and "Trainer", all built from the config.
    import tensorflow as tf

    # config is the configuration object parsed from your JSON file
    sess = tf.Session()
    # create an instance of the model you want
    model = VGGModel(config)
    # create your data generator
    data = DataGenerator(config)
    # create the TensorBoard logger
    logger = Logger(sess, config)
  • Pass all these objects to the trainer object, and start training by calling "trainer.train()".
    trainer = VGGTrainer(sess, model, data, config, logger)

    # here you train your model
    trainer.train()

You will find a template file and a simple example in the model and trainer folders that show how to try your first model.

In Details

Project architecture

Folder structure

├── base
│   ├── base_model.py      - this file contains the abstract class of the model.
│   └── base_train.py      - this file contains the abstract class of the trainer.
│
├── model                  - this folder contains any model of your project.
│   └── example_model.py
│
├── trainer                - this folder contains the trainers of your project.
│   └── example_trainer.py
│
├── mains                  - here are the main(s) of your project (you may need more than one main).
│   └── example_main.py    - an example main that is responsible for the whole pipeline.
│
├── data_loader
│   └── data_generator.py  - the data generator that is responsible for all data handling.
│
└── utils
    ├── logger.py
    └── any_other_utils_you_need

Main Components

Models


  • Base model

    The base model is an abstract class that every model you create must inherit from. The idea is that all models share a lot of common functionality. The base model contains:

    • Save - saves a checkpoint to disk.
    • Load - loads a checkpoint from disk.
    • Cur_epoch, Global_step counters - variables that keep track of the current epoch and global step.
    • Init_Saver - an abstract function that initializes the saver used for saving and loading checkpoints. Note: override this function in the model you want to implement.
    • Build_model - an abstract function that defines the model. Note: override this function in the model you want to implement.
  • Your model

    This is where you implement your model. You should:

    • Create your model class and inherit from the base_model class.
    • Override "build_model", where you write the TensorFlow model you want.
    • Override "init_saver", where you create a TensorFlow saver used to save and load checkpoints.
    • Call "build_model" and "init_saver" in the initializer.
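The contract between the base model and your model can be sketched without TensorFlow. This is only an illustration under stated assumptions, not the template's actual base_model.py: `RecordingSaver` and `ToyModel` are hypothetical stand-ins, and the real saver would be a `tf.train.Saver` as shown earlier.

```python
class BaseModel:
    """Framework-free stand-in for the abstract model described above."""

    def __init__(self, config):
        self.config = config
        self.saver = None      # must be set by init_saver() in the subclass
        self.cur_epoch = 0     # counters kept by the base class
        self.global_step = 0

    def save(self, checkpoint_path):
        # Refuse to save until the subclass has created a saver.
        if self.saver is None:
            raise RuntimeError("call init_saver() before save()")
        self.saver.save(checkpoint_path, self.global_step)

    def init_saver(self):
        raise NotImplementedError  # abstract: override in your model

    def build_model(self):
        raise NotImplementedError  # abstract: override in your model


class RecordingSaver:
    """Hypothetical saver that only records what it was asked to save."""

    def __init__(self):
        self.saved = []

    def save(self, path, step):
        self.saved.append((path, step))


class ToyModel(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        # Same pattern as in the text: build first, then create the saver.
        self.build_model()
        self.init_saver()

    def build_model(self):
        self.built = True  # a real model would define its graph and loss here

    def init_saver(self):
        self.saver = RecordingSaver()


model = ToyModel({"max_to_keep": 5})
model.save("checkpoints/toy")
```

Calling `build_model` or `init_saver` on `BaseModel` directly raises `NotImplementedError`, which is what forces each concrete model to provide both.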

Trainer


  • Base trainer

    The base trainer is an abstract class that just wraps the training process.

  • Your trainer

    Here is what you should implement in your trainer:

    1. Create your trainer class and inherit from the base_train class.
    2. Override the two functions "train_step" and "train_epoch", where you implement the training process of each step and each epoch.
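As a rough sketch of how the epoch and step functions fit together, the following framework-free example wraps a training loop the way the text describes. It is an assumption about the shape of the base class, not a copy of base_train.py; the config fields `num_epochs` and `num_iter_per_epoch` and the `CountingTrainer` class are hypothetical.

```python
from types import SimpleNamespace


class BaseTrain:
    """Stand-in for the abstract trainer: it only wraps the loop."""

    def __init__(self, sess, model, data, config, logger):
        self.sess = sess
        self.model = model
        self.data = data
        self.config = config
        self.logger = logger

    def train(self):
        # One call to train_epoch per configured epoch.
        for _ in range(self.config.num_epochs):
            self.train_epoch()

    def train_epoch(self):
        raise NotImplementedError  # override in your trainer

    def train_step(self):
        raise NotImplementedError  # override in your trainer


class CountingTrainer(BaseTrain):
    """Toy trainer whose 'loss' is just 1 / (number of steps so far)."""

    def __init__(self, sess, model, data, config, logger):
        super().__init__(sess, model, data, config, logger)
        self.steps = 0
        self.epoch_losses = []

    def train_epoch(self):
        # Loop over the iterations in the config and call the train step.
        losses = [self.train_step()
                  for _ in range(self.config.num_iter_per_epoch)]
        self.epoch_losses.append(sum(losses) / len(losses))

    def train_step(self):
        # A real step would run the session and return metrics.
        self.steps += 1
        return 1.0 / self.steps


config = SimpleNamespace(num_epochs=2, num_iter_per_epoch=3)
trainer = CountingTrainer(None, None, None, config, None)
trainer.train()
```

Keeping the outer loop in the base class means a new trainer only has to describe what one step and one epoch do.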

Data Loader

This class is responsible for all data handling and processing, and it provides an easy interface that the trainer can use.
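One possible shape for such an interface is a single method that hands the trainer a random batch. The sketch below uses toy in-memory data and only the standard library; the method name `next_batch`, the `seed` argument, and the data itself are assumptions for illustration and may differ from the template's data_generator.py.

```python
import random


class DataGenerator:
    """Toy data loader exposing one batch-fetching method to the trainer."""

    def __init__(self, config, seed=0):
        self.config = config
        self._rng = random.Random(seed)
        # Toy dataset: 100 two-feature inputs with binary labels.
        self.inputs = [[float(i), float(i) * 2.0] for i in range(100)]
        self.labels = [i % 2 for i in range(100)]

    def next_batch(self, batch_size):
        # Sample distinct indices, then gather inputs and labels together
        # so that each input stays paired with its label.
        indices = self._rng.sample(range(len(self.inputs)), batch_size)
        batch_x = [self.inputs[i] for i in indices]
        batch_y = [self.labels[i] for i in indices]
        return batch_x, batch_y


data = DataGenerator(config=None)
batch_x, batch_y = data.next_batch(8)
```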

Logger

This class is responsible for the TensorBoard summaries. In your trainer, create a dictionary of all the TensorFlow variables you want to summarize, then pass this dictionary to logger.summarize().
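The calling pattern looks roughly like this. `InMemoryLogger` is a hypothetical stand-in that stores values in a list instead of writing TensorBoard summaries, and the `summarize(step, summaries_dict)` signature is an assumption; check utils/logger.py for the real one.

```python
class InMemoryLogger:
    """Hypothetical logger that keeps summaries in memory."""

    def __init__(self):
        self.records = []

    def summarize(self, step, summaries_dict):
        # Store one (step, tag, value) record per dictionary entry.
        for tag, value in summaries_dict.items():
            self.records.append((step, tag, float(value)))


logger = InMemoryLogger()
# In a trainer, these values would come from train_step().
summaries_dict = {"loss": 0.42, "acc": 0.9}
logger.summarize(step=1, summaries_dict=summaries_dict)
```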

Comet.ml Integration

This template also supports reporting to Comet.ml, which lets you see all your hyperparameters, metrics, graphs, dependencies, and more, including real-time metrics.

Add your API key in the configuration file:

For example: "comet_api_key": "your key here"

You can also link your GitHub repository to your Comet.ml project for full version control. Here's a live page showing the example from this repo.

Configuration

I use JSON as the configuration format: write all the settings you want in a JSON file, parse it using "utils/config/process_config", and pass the resulting configuration object to all other objects.
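A minimal sketch of this idea, assuming the goal is attribute-style access such as `config.max_to_keep`: the helper `load_config` below is hypothetical and is not the template's `process_config`. The keys "max_to_keep" and "comet_api_key" appear elsewhere in this README; "exp_name" is made up for the example.

```python
import json
import os
import tempfile
from types import SimpleNamespace


def load_config(json_path):
    """Parse a JSON file into an object with attribute access."""
    with open(json_path) as f:
        # object_hook turns every JSON object into a SimpleNamespace.
        return json.load(f, object_hook=lambda d: SimpleNamespace(**d))


# Write a small config to a temporary file, then parse it back.
settings = {
    "exp_name": "example",
    "max_to_keep": 5,
    "comet_api_key": "your key here",
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(settings, f)

config = load_config(f.name)
os.remove(f.name)
```

The resulting `config` object can then be handed to the model, data generator, logger, and trainer, which read the fields they need.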

Main

This is where you combine all the previous parts.

  1. Parse the config file.
  2. Create a TensorFlow session.
  3. Create an instance of "Model", "Data_Generator", and "Logger", and pass the config to all of them.
  4. Create an instance of "Trainer" and pass all the previous objects to it.
  5. Now you can train your model by calling "trainer.train()".

Future Work

  • Replace the data loader part with the new TensorFlow Dataset API.

Contributing

Any kind of enhancement or contribution is welcome.

Acknowledgments

Thanks to my colleague Mo'men Abdelrazek for contributing to this work, and thanks to Mohamed Zahran for the review. Thanks to Jtoy for including the repo in Awesome Tensorflow.

Owner
Mahmoud Gamal Salem
MSc in AI at the University of Guelph and Vector Institute. AI intern @samsung