Classifies galaxy morphology with a Bayesian CNN


Zoobot

Zoobot classifies galaxy morphology with deep learning. This code will let you:

  • Reproduce and improve the Galaxy Zoo DECaLS automated classifications
  • Finetune the classifier for new tasks

For example, you can train a new classifier like so:

model = define_model.get_model(
    output_dim=len(schema.label_cols),  # schema defines the questions and answers
    input_size=initial_size, 
    crop_size=int(initial_size * 0.75),
    resize_size=resize_size
)

model.compile(
    loss=losses.get_multiquestion_loss(schema.question_index_groups),
    optimizer=tf.keras.optimizers.Adam()
)

training_config.train_estimator(
    model, 
    train_config,  # parameters for how to train e.g. epochs, patience
    train_dataset,
    test_dataset
)
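The snippet above reads `label_cols` and `question_index_groups` from a `schema` object. As a rough illustration of what those two attributes might hold, here is a pure-Python sketch. The question names, both helper functions, and the inclusive `[start, end]` index convention are assumptions made for this example and are not Zoobot's actual schema API.

```python
# Hypothetical decision tree: each question maps to its possible answers.
question_answer_pairs = {
    "smooth-or-featured": ["_smooth", "_featured-or-disk"],
    "bar": ["_strong", "_weak", "_no"],
}


def build_label_cols(pairs):
    """Flatten questions and answers into one list of label column names."""
    return [question + answer for question, answers in pairs.items() for answer in answers]


def build_question_index_groups(pairs):
    """Return one [start, end] pair (inclusive, assumed convention) per question."""
    groups = []
    start = 0
    for answers in pairs.values():
        end = start + len(answers) - 1
        groups.append([start, end])
        start = end + 1
    return groups


label_cols = build_label_cols(question_answer_pairs)
question_index_groups = build_question_index_groups(question_answer_pairs)
output_dim = len(label_cols)  # one model output per answer
```

Under these assumptions the model has one output per answer, and the loss can use the index groups to tell which outputs belong to the same question.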

Install using git and pip:

git clone git@github.com:mwalmsley/zoobot.git
pip install -r zoobot/requirements.txt  # virtual env or conda highly recommended
pip install -e zoobot

The main branch is for stable-ish releases. The dev branch includes the shiniest features but may change at any time.

To get started, see the documentation.

I also include some working examples for you to copy and adapt.

Latest cool features on dev branch (June 2021):

  • Multi-GPU distributed training
  • Support for Weights and Biases (wandb)
  • Worked examples for custom representations

Contributions are welcome and will be credited in any future work.

If you use this repo for your research, please cite the paper.

Comments
  • Benchmarks

    It's important that Zoobot has proper benchmarks so that we can be confident new releases work properly for users. This PR adds those benchmarks.

    In the course of setting up the benchmarks, I have made some major changes/improvements:

    • pytorch-galaxy-datasets refactored to work for tensorflow, imports adapted
    • both tensorflow and pytorch zoobot versions use albumentations for augmentations. Old TF code removed.
    • tensorflow version bumped to 2.10 (current latest) while I'm at it
    • pytorch version now has logging for per-question loss. Loss func aggregation has new option to support this.
    • TensorFlow version has per-question logging also, but awaiting issue with Keras team to enable
    • Created minimal_example.py for TensorFlow (thanks, @katgre )
    • Support CPU-only PyTorch training
    • Refactor TF TrainingConfig to Trainer object, Lightning style, for consistency
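The per-question loss logging described in the list above could look roughly like this. It is a minimal pure-Python sketch, assuming the loss has already been computed once per question. The function name, the `sum_over_questions` flag, and the toy numbers are illustrative and are not taken from Zoobot.

```python
def aggregate_loss(per_question_loss, question_names, sum_over_questions=True):
    """Either sum the loss over questions or keep one named value per question.

    per_question_loss: list of floats, one per question (assumed precomputed).
    question_names: list of str, same order as per_question_loss.
    """
    if len(per_question_loss) != len(question_names):
        raise ValueError("need exactly one loss value per question")
    if sum_over_questions:
        # Single scalar, suitable for backpropagation.
        return sum(per_question_loss)
    # One entry per question, suitable for logging.
    return dict(zip(question_names, per_question_loss))


names = ["smooth-or-featured", "bar"]  # hypothetical question names
losses = [0.5, 0.25]                   # toy values
total = aggregate_loss(losses, names)
per_question = aggregate_loss(losses, names, sum_over_questions=False)
```

Keeping the unsummed values around makes it possible to log each question separately while still training on the total.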
    enhancement 
    opened by mwalmsley 3
  • on_train_batch_end is slow in TF

    Unclear what's causing this slowness. Presumably a callback I added - but none look like they should be heavy? Perhaps something wandb is doing?

    • Remove all callbacks and rerun
    • Remove wandb and rerun

    For each, check whether the slow warning continues (or whether training speed changes at all).
    enhancement 
    opened by mwalmsley 3
  • add gh action to publish package to pypi

    Related to https://github.com/mwalmsley/zoobot/issues/18#issuecomment-1278635788

    This PR adds an auto CI release mechanism for publishing zoobot to pypi. It uses the GH action to release to pypi https://github.com/pypa/gh-action-pypi-publish

    opened by camallen 3
  • Publish latest version to PyPi?

    A question rather than a request. Are there any plans to publish the refactored work?

    PyPI shows that v0.0.1 was published on 15th March 2021 (https://pypi.org/project/zoobot/#history), but the latest code is ~v0.0.3 (tags) and the refactor seems to be working well.

    Ideally I can pull in these packages to my own env / container and then train with the latest code vs pulling in from github etc.

    opened by camallen 3
  • setup branch protection rules on 'main'

    https://docs.github.com/en/repositories/configuring-branches-and-merges-in-your-repository/defining-the-mergeability-of-pull-requests/managing-a-branch-protection-rule

    It may be too restrictive for your use case / dev flows, but we use this for contributor PRs etc. Basically, we ensure that a PR meets certain CI criteria: a PR can only be merged once the tests in one of the CI runs (v3.7 or v3.9) pass.

    Feel free to close if you don't think this is useful.

    enhancement 
    opened by camallen 2
  • Deprecate TFRecords

    TFRecords are cumbersome and take up a lot of disk space. It's much simpler to learn directly from images on disk, at the cost of some I/O performance.

    This PR removes support for TFRecords in favour of images-on-disk. This will ultimately enable new TensorFlow weights trained on all of DESI (impractical with TFRecords).

    Breaking change for anyone using TFRecords (i.e. everyone using TensorFlow to train from scratch). Finetuning should not be affected.

    TODO - will require new greyscale/colour pretrained models, just for safety.

    opened by mwalmsley 2
  • feat(CI): Add proposed python CI GH Action

    This PR proposes to add a simple GH Action script that establishes a python environment, downloads the requirements and runs pytest.

    Some other things to consider might be to use conda for virtual environments and creating CI scripts for Docker as well.

    opened by SauravMaheshkar 2
  • Improve data files for docker

    This PR changes the docker / compose setup, specifically it

    • consolidates the docker files to cuda and tensorflow base images (no need for a python base image)
    • adds a .dockerignore entry for all data files when building the container to keep the size down
    • and provides an easy way to inject them at run time via local directory mounts in the compose file
    • finally, this removes the local directory setup, specific to my machine, for injecting unrelated data files
    opened by camallen 2
  • add wandb logging, freeze batchnorm by default

    Doing some polishing on finetuning

    • Add wandb logging to the full_tree example. @camallen use this for the dashboard. When running on Azure, you will need to add import wandb and wandb.init(authkey, etc) just before.
    • Freeze batch norm layers by default when finetuning, with new recursive function
    • Pass additional params via config (thanks Cam)
    • Minor cleanup
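The recursive batch norm freeze mentioned in the list above can be sketched without any deep learning framework. The `Layer` class below is a hypothetical stand-in for a real Keras or PyTorch layer, and `freeze_batchnorm` is an illustrative name. The point is only the recursion into nested blocks, not Zoobot's actual function.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Layer:
    """Stand-in for a framework layer; real code would use Keras or PyTorch layers."""
    kind: str
    trainable: bool = True
    sublayers: List["Layer"] = field(default_factory=list)


def freeze_batchnorm(layer):
    """Recursively mark every batch norm layer as non-trainable.

    Returns the number of layers that were frozen.
    """
    frozen = 0
    if layer.kind == "batchnorm":
        layer.trainable = False
        frozen += 1
    for sub in layer.sublayers:
        # Nested blocks (models inside models) are handled by recursion.
        frozen += freeze_batchnorm(sub)
    return frozen


model = Layer("model", sublayers=[
    Layer("conv"),
    Layer("batchnorm"),
    Layer("block", sublayers=[Layer("conv"), Layer("batchnorm")]),
])
n_frozen = freeze_batchnorm(model)
```

A flat loop over the top-level layers would miss the batch norm layer inside the nested block, which is why the traversal is recursive.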
    opened by mwalmsley 1
  • Add PyTorch Finetuning Capability, Examples

    Key change is adding pytorch.training.finetune() method. Works on either classification (e.g. 0, 1) data or count (e.g. 12 said yes, 4 said no) data.
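To make the two label formats concrete, here is a small pure-Python sketch. The column names (`ring_yes`, `ring_no`) and the `vote_fractions` helper are hypothetical and only illustrate the difference between class labels and vote counts. They are not part of the finetune() API.

```python
# Classification data: one integer class per galaxy (e.g. 0 = no ring, 1 = ring).
class_labels = [0, 1, 1]

# Count data: how many volunteers gave each answer for each galaxy.
count_labels = [
    {"ring_yes": 12, "ring_no": 4},
    {"ring_yes": 0, "ring_no": 0},  # a galaxy nobody has voted on yet
]


def vote_fractions(counts):
    """Convert raw vote counts into the fraction of votes per answer."""
    total = sum(counts.values())
    if total == 0:
        # No votes: avoid dividing by zero and report zero for every answer.
        return {answer: 0.0 for answer in counts}
    return {answer: n / total for answer, n in counts.items()}


fractions = [vote_fractions(counts) for counts in count_labels]
```

Counts carry more information than a single class label: 12 of 16 votes and 3 of 4 votes give the same fraction but very different certainty.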

    Includes three working examples:

    • Binary classification, with tiny rings subset
    • Counts for single question, with full internal rings data
    • Counts for all questions, with GZ Cosmic Dawn schema

    Also updates various imports for the galaxy-datasets refactor, fixes prediction method to work on unlabelled data, minor QoL improvements.

    Finally, changes PyTorch dense layer initialisation to custom high-uncertainty initialisation - see efficientnet_custom.py

    cc @camallen

    opened by mwalmsley 1
  • Add v0.02 changes

    Adds support (minimal working examples, a guide) for calculating new representations with a trained model.

    Also adds significant new features:

    • Distributed training with several GPUs
    • Metric logging with Weights&Biases (add your own login credentials)
    • Train on color (3-band) images, not just greyscale

    Also adds a critical bugfix: when loading images for direct predictions (i.e. not via TFRecords), images are now correctly normalised to the 0-1 interval that the tf.keras.experimental.preprocessing layers expect (an expectation that is not documented).
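The normalisation in that bugfix amounts to rescaling pixel values before they reach the model. Below is a minimal pure-Python sketch, assuming 8-bit images with values from 0 to 255. The function name is illustrative, and real code would apply the same arithmetic to an array or tensor rather than nested lists.

```python
def normalise_pixels(pixels, max_value=255):
    """Scale integer pixel values into the 0-1 interval.

    pixels: nested list of rows of ints in [0, max_value] (8-bit assumed).
    """
    return [[value / max_value for value in row] for row in pixels]


image = [[0, 255], [51, 204]]  # toy 2x2 greyscale image
normalised = normalise_pixels(image)
```

Skipping this step would feed values up to 255 into layers that expect values no larger than 1.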

    Also adds misc. minor fixes and documentation tweaks.

    This code was used for the morphology tools paper (to be submitted shortly).

    opened by mwalmsley 1
  • Avoid --extra-index-url via dependency_links

    It should be possible to search for non-standard package repositories using just setup.py, without having the user also set --extra-index-url.

    https://setuptools.pypa.io/en/latest/deprecated/dependency_links.html

    But I couldn't get this to work on a quick try.

    enhancement help wanted 
    opened by mwalmsley 1
  • Can't import finetune while going through finetune_binary_classification.py

    I tried to go through finetune_binary_classification.py, but got the error:

    ImportError: cannot import name 'finetune' from 'zoobot.pytorch.training' (/usr/local/lib/python3.8/dist-packages/zoobot/pytorch/training/__init__.py)

    I tried it with both the kasia and dev branches, went through "git clone" and "pip install" (I remember there were some issues with that during the Hackathon), and also tried importing other features from the folder (e.g. losses), which worked fine.

    bug 
    opened by katgre 2
  • Create a simple decision tree in minimal_example.py

    Instead of using one of the complicated decision trees from DECaLS DR5, let's create a simple decision tree with one dependency, written directly in minimal_example.py.

    opened by katgre 0
Releases(v0.0.3)
  • v0.0.3(Apr 25, 2022)

    Improved documentation and refactored train API (pytorch).

    Awaiting results from several segmentation experiments ahead of public release (inc pytorch version).
  • v0.0.2(Oct 4, 2021)

  • beta(Sep 29, 2021)

    Initial release.

    This had enough documentation and code to replicate the DECaLS model and make predictions. There are a few minor missing arguments and similar typos that you might have stumbled into, because I made some last minute changes without updating the docs, but everything worked with a little stack tracing.
Owner
Mike Walmsley