Official PyTorch Implementation of "AgentFormer: Agent-Aware Transformers for Socio-Temporal Multi-Agent Forecasting".

AgentFormer

This repo contains the official implementation of our paper:

AgentFormer: Agent-Aware Transformers for Socio-Temporal Multi-Agent Forecasting
Ye Yuan, Xinshuo Weng, Yanglan Ou, Kris Kitani
ICCV 2021
[website] [paper]

Overview

(Figure: AgentFormer overview)

Important Note

We have recently noticed a normalization bug in the code; after fixing it, the performance of our method is worse than the original numbers reported in the ICCV paper. For comparison, please use the corrected numbers in the updated arXiv version.

Installation

Environment

  • Tested OS: macOS, Linux
  • Python >= 3.7
  • PyTorch == 1.8.0

Dependencies:

  1. Install PyTorch 1.8.0 with the correct CUDA version (see the example after this list).
  2. Install the dependencies:
    pip install -r requirements.txt
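
The exact PyTorch command depends on your CUDA setup. As an example, assuming a CUDA 11.1 environment and the official wheel index, the following should work (check pytorch.org for the command matching your CUDA version):

    # Example only: PyTorch 1.8.0 / torchvision 0.9.0 built against CUDA 11.1
    pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html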
    

Datasets

  • For the ETH/UCY dataset, we have already included a converted version compatible with our dataloader under datasets/eth_ucy.
  • For the nuScenes dataset, the following steps are required:
    1. Download the original nuScenes dataset. Check out the instructions here.
    2. Follow the instructions of the nuScenes prediction challenge. Download and install the map expansion.
    3. Run our script to obtain a processed version of the nuScenes dataset under datasets/nuscenes_pred (a consolidated sketch of these steps follows):
      python data/process_nuscenes.py --data_root <PATH_TO_NUSCENES>
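
For reference, assuming the dataset lives at a hypothetical local path /data/nuscenes (the download itself requires registering on the nuScenes website), the three steps boil down to:

      # 1. Download the original nuScenes dataset (registration required), e.g. to /data/nuscenes
      # 2. Install the devkit and extract the map expansion zip into /data/nuscenes/maps
      pip install nuscenes-devkit
      # 3. Convert to the format expected by our dataloader (output goes to datasets/nuscenes_pred)
      python data/process_nuscenes.py --data_root /data/nuscenes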
      

Pretrained Models

  • You can download pretrained models from Google Drive or BaiduYun (password: 9rvb) to reproduce the numbers in the paper.
  • Once the agentformer_models.zip file is downloaded, place it under the root folder of this repo and unzip it:
    unzip agentformer_models.zip
    
    This will place the models under the results folder. Note that the pretrained models directly correspond to the config files in cfg.

Evaluation

ETH/UCY

Run the following command to test pretrained models for the ETH dataset:

python test.py --cfg eth_agentformer --gpu 0

You can replace eth with {hotel, univ, zara1, zara2} to test other datasets in ETH/UCY. You should be able to get the numbers reported in the paper as shown in this table:

Dataset  ADE   FDE
ETH      0.45  0.75
Hotel    0.14  0.22
Univ     0.25  0.45
Zara1    0.18  0.30
Zara2    0.14  0.24
Avg      0.23  0.39
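
To evaluate all five ETH/UCY splits in one pass, you can loop over the config names (this assumes the <dataset>_agentformer naming used above and a bash-like shell):

    # Evaluate every ETH/UCY split with its pretrained model
    for d in eth hotel univ zara1 zara2; do
        python test.py --cfg ${d}_agentformer --gpu 0
    done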

nuScenes

Run the following command to test pretrained models for the nuScenes dataset:

python test.py --cfg nuscenes_5sample_agentformer --gpu 0

You can replace 5sample with 10sample to compute all the metrics (ADE_5, FDE_5, ADE_10, FDE_10). You should be able to get the numbers reported in the paper as shown in this table:

Model  ADE_5  FDE_5  ADE_10  FDE_10
Ours   1.856  3.889  1.452   2.856
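
To obtain all four metrics in one go, run both configs back to back:

    python test.py --cfg nuscenes_5sample_agentformer --gpu 0
    python test.py --cfg nuscenes_10sample_agentformer --gpu 0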

Training

You can train your own models with your customized configs. Here we take the ETH dataset as an example, but you can train models for other datasets with their corresponding configs. AgentFormer requires two-stage training:

  1. Train the AgentFormer VAE model (everything but the trajectory sampler):
    python train.py --cfg user_eth_agentformer_pre --gpu 0
    
  2. Once the VAE model is trained, train the AgentFormer DLow model (trajectory sampler):
    python train.py --cfg user_eth_agentformer --gpu 0
    
    Note that you need to change the pred_cfg field in user_eth_agentformer to the config you used in step 1 (user_eth_agentformer_pre) and change the pred_epoch field to the VAE model epoch you want to use (see the sketch below).
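
For reference, the relevant fields of the DLow config might look like the following sketch (the file sits under cfg; the exact file name, extension, and epoch value here are placeholders):

    # cfg/user_eth_agentformer.yml (sketch; only the fields mentioned above)
    pred_cfg: user_eth_agentformer_pre    # VAE config trained in step 1
    pred_epoch: 100                       # placeholder: epoch of the VAE checkpoint to load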

Citation

If you find our work useful in your research, please cite our paper AgentFormer:

@inproceedings{yuan2021agent,
  title={AgentFormer: Agent-Aware Transformers for Socio-Temporal Multi-Agent Forecasting},
  author={Yuan, Ye and Weng, Xinshuo and Ou, Yanglan and Kitani, Kris},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year={2021}
}

License

Please see the license for further details.
