Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch.

Overview

SE3 Transformer - Pytorch

This implementation may be needed for replicating Alphafold2 results and for other drug discovery applications.

Install

$ pip install se3-transformer-pytorch

Usage

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask  = torch.ones(1, 1024).bool()

out = model(feats, coors, mask) # (1, 1024, 512)

Potential example usage in Alphafold2, as outlined here

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atom_feats = torch.randn(2, 32, 64)
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(atom_feats, coors, mask, return_type = 1) # (2, 32, 3)

You can also let the base transformer class take care of embedding the type 0 features being passed in, assuming they are atoms.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,       # 28 unique atoms
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(atoms, coors, mask, return_type = 1) # (2, 32, 3)

If you think the net could further benefit from positional encoding, you can featurize your positions in space and pass them in as follows.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 2,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True  # reduce out the final dimension
)

atom_feats  = torch.randn(2, 32, 64, 1) # b x n x d x type0
coors_feats = torch.randn(2, 32, 64, 3) # b x n x d x type1

# atom features are type 0, predicted coordinates are type 1
features = {'0': atom_feats, '1': coors_feats}
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(features, coors, mask, return_type = 1) # (2, 32, 3) - equivariant to input type 1 features and coordinates
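As a rough illustration of the shape convention in the example above, a degree-l (type l) feature carries 2l + 1 components per channel, so type 0 is a scalar and type 1 is a 3-vector. The helper name `feature_shape` is hypothetical and not part of the library; it only spells out the expected trailing dimension.

```python
def feature_shape(batch, num_points, dim, degree):
    """Expected shape (b, n, d, 2 * degree + 1) of a type-`degree` feature tensor.

    Hypothetical helper for illustration; not part of se3_transformer_pytorch.
    """
    if degree < 0:
        raise ValueError("degree must be non-negative")
    return (batch, num_points, dim, 2 * degree + 1)


# Shapes matching the example above: type 0 and type 1 features
shapes = {str(degree): feature_shape(2, 32, 64, degree) for degree in range(2)}
print(shapes)  # {'0': (2, 32, 64, 1), '1': (2, 32, 64, 3)}
```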

Edges

To offer edge information to SE3 Transformers (say bond types between atoms), you just have to pass in two more keyword arguments on initialization.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,
    dim = 64,
    num_edge_tokens = 4,       # number of edge types, say 4 bond types
    edge_dim = 16,             # dimension of edge embedding
    depth = 2,
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

pred = model(atoms, coors, mask, edges = bonds, return_type = 0) # (2, 32, 1)

Caching

By default, the basis vectors are cached. However, if you ever need to clear the cache, simply set the environment variable CLEAR_CACHE to some value when launching the script

$ CLEAR_CACHE=1 python train.py

Or you can try deleting the cache directory, which should exist at

$ rm -rf ~/.cache.equivariant_attention
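The same cleanup can be sketched from Python, for cases where a shell is not convenient. This is a minimal sketch, assuming the cache lives at the path given above; the function name `clear_basis_cache` is hypothetical and not part of the library.

```python
import shutil
from pathlib import Path

# Path taken from the README; whether the library resolves it identically is an assumption.
DEFAULT_CACHE_DIR = Path.home() / ".cache.equivariant_attention"


def clear_basis_cache(cache_dir=DEFAULT_CACHE_DIR):
    """Delete the cache directory if it exists. Returns True if something was removed."""
    cache_dir = Path(cache_dir)
    if not cache_dir.is_dir():
        return False
    shutil.rmtree(cache_dir)
    return True
```

Calling `clear_basis_cache()` before the model is constructed should have roughly the same effect as the `rm -rf` command; the next run would then presumably rebuild the cached basis.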

Testing

$ python setup.py pytest

Credit

This library is largely a port of Fabian's official repository, but without the DGL library.

Citations

@misc{fuchs2020se3transformers,
    title   = {SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks}, 
    author  = {Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling},
    year    = {2020},
    eprint  = {2006.10503},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Breaking equivariance

    Hi, Thanks a lot for your work!!

    I was running some equivariance tests on your implementation of the SE3-Transformer and found that equivariance is not always preserved. It does not break every time, and unfortunately I do not know where the bug is.

    I have attached an image with an example of equivariance not being preserved.

    image

    To generate this image I used the following code:

        import numpy as np
        import matplotlib.pyplot as plt
        import torch
        
        from se3_transformer_pytorch import SE3Transformer
    

    Make some data

        zline = np.arange(0, 2, 0.05)
        xline = np.sin(zline * 2 * np.pi) 
        yline = np.cos(zline * 2 * np.pi)
        points = np.array([xline, yline, zline])
        geom = torch.tensor(points.transpose())[None,:].float()
        feat = torch.randint(0, 20, (1, geom.shape[1],1)).float()
        
        def rot_matrix(x):
            # Rotation matrix
            a ,b ,c = 2*np.pi*x
            return np.array([
                [np.cos(a)*np.cos(b), np.cos(a)*np.sin(b)*np.sin(c)- np.sin(a)*np.cos(c), np.cos(a)*np.sin(b)*np.cos(c)+ np.sin(a)*np.sin(c)],
                [np.sin(a)*np.cos(b), np.sin(a)*np.sin(b)*np.sin(c)+ np.cos(a)*np.cos(c), np.sin(a)*np.sin(b)*np.cos(c)- np.cos(a)*np.sin(c)],
                [-np.sin(b)         , np.cos(b)*np.sin(c)                               , np.cos(b)*np.cos(c)                               ]
            ])
    

    Initialize model

        mdl = SE3Transformer(
            dim = 1,
            depth = 3,
            input_degrees = 1,
            num_degrees = 2,
            output_degrees = 2,
            reduce_dim_out = True,
        )
        
        def model(geom,feat):
            return geom + mdl(feat,geom, return_type = 1)
    

    Check rotation equivariance:

        with torch.no_grad():
            
            Q = torch.tensor(rot_matrix(np.random.random(3))).float()
            prerotated = model(geom @ Q, feat).squeeze().detach().numpy().transpose()
            posrotated = (model(geom, feat) @ Q).squeeze().detach().numpy().transpose()
        
            fig = plt.figure(dpi = 200)
            ax = plt.axes(projection="3d")
        
            ax.plot3D(prerotated[0], prerotated[1], prerotated[2], "r", linewidth=1.1)
            ax.plot3D(posrotated[0], posrotated[1], posrotated[2], "b", linewidth=0.5)
        
            plt.legend(["Pre-Rotated", "Post-Rotated"])
            plt.show()
    

    Check translation equivariance:

        with torch.no_grad():
            
            x0 = 1*torch.rand(3)
            prerotated = model(geom + x0, feat).squeeze().detach().numpy().transpose()
            posrotated = (model(geom, feat) + x0).squeeze().detach().numpy().transpose()
        
            fig = plt.figure(dpi = 200)
            ax = plt.axes(projection="3d")
        
            ax.plot3D(prerotated[0], prerotated[1], prerotated[2], "r", linewidth=1.1)
            ax.plot3D(posrotated[0], posrotated[1], posrotated[2], "b", linewidth=0.5)
        
            plt.legend(["Pre-Translated", "Post-Translated"])
            plt.show()
    

    Hope this helps!!

    opened by brennanaba 18
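A numeric check can complement the plots in the report above by reducing the comparison to a single number. The sketch below uses plain Python lists and a toy equivariant map as a stand-in for the model, so it runs without the library; the names `max_equivariance_error`, `pull_to_centroid` and `rot_z` are all hypothetical.

```python
import math


def rotate(points, rot):
    """Compute points @ rot for a list of 3D row vectors and a 3x3 matrix."""
    return [[sum(p[k] * rot[k][j] for k in range(3)) for j in range(3)] for p in points]


def max_equivariance_error(fn, points, rot):
    """Largest absolute difference between fn(points @ rot) and fn(points) @ rot."""
    pre = fn(rotate(points, rot))
    post = rotate(fn(points), rot)
    return max(abs(a - b) for p, q in zip(pre, post) for a, b in zip(p, q))


def pull_to_centroid(points):
    """Toy equivariant map (stand-in for a model): move each point halfway to the centroid."""
    n = len(points)
    center = [sum(p[j] for p in points) / n for j in range(3)]
    return [[0.5 * (p[j] + center[j]) for j in range(3)] for p in points]


def rot_z(angle):
    """Rotation about the z-axis by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


# Helix similar to the one used in the report above
helix = [[math.sin(2 * math.pi * z), math.cos(2 * math.pi * z), z]
         for z in (0.05 * i for i in range(40))]

err = max_equivariance_error(pull_to_centroid, helix, rot_z(0.7))
print(err < 1e-9)  # True: an equivariant map leaves only floating-point noise
```

With the real model, `fn` would wrap something like `geom + mdl(feat, geom, return_type = 1)` on tensors, and the tolerance would need to be looser for float32; the right threshold is a judgment call.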
  • How to populate input variable length data

    Hello, thank you for your work. I am using your SE3-Transformer implementation as part of my project, but I have run into problems when feeding variable-length data into the model. I do not know how to pad the features, coordinates and masks to meet the needs of the model.

    The input data in my Dataset is processed as follows: (image) I pad the coordinates and features with zeros so that their shapes are (246, 3) and (246, 20) respectively, and fill the mask with boolean True and False values. In this case there are 49 valid entries, so the valid features and coordinates have shapes (49, 20) and (49, 3), and the rest is zero padding. The first 49 entries of the mask are True and the rest are False.

    I then checked that the input dimensions coming out of torch.utils.data.DataLoader are correct: (image) My network, a typical binary classifier, looks like this; I output the result of the SE3 module:

    image

    However, the output of the SE3 module contains NaN, as shown in the figure below. (image) The first 48 entries have results and the rest are NaN; it is not clear what the problem is.

    In the next batch everything may be NaN because of the model weight updates. (image) The SE3 output causes all my losses to be NaN.

    I think these NaN values are probably caused by incorrect padding on my part. I also tried padding with 1, but that did not help either. Maybe the mask I pass in cannot contain False; I am confused about that. Could you please tell me how to correctly pad the coordinates, features and masks, or give me some advice on using your SE3 model with variable-length data? Thank you!

    opened by zyk19981118 8
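The padding layout described in the question above can be sketched as follows. This is a minimal sketch using nested Python lists; the function name `pad_batch` is hypothetical, and it only shows the data layout (zero padding plus a boolean mask). It does not by itself explain or fix the NaN values reported.

```python
def pad_batch(coords_list, feats_list, pad_value=0.0):
    """Pad variable-length samples to a common length and build a boolean mask.

    coords_list[i] is a list of [x, y, z]; feats_list[i] is a list of feature vectors.
    Returns (coords, feats, mask) as nested lists, with mask True for real entries.
    """
    max_len = max(len(c) for c in coords_list)
    feat_dim = len(feats_list[0][0])
    coords, feats, mask = [], [], []
    for c, f in zip(coords_list, feats_list):
        if len(c) != len(f):
            raise ValueError("coordinates and features must have the same length")
        n_pad = max_len - len(c)
        coords.append([list(p) for p in c] + [[pad_value] * 3 for _ in range(n_pad)])
        feats.append([list(v) for v in f] + [[pad_value] * feat_dim for _ in range(n_pad)])
        mask.append([True] * len(c) + [False] * n_pad)
    return coords, feats, mask
```

The three results could then be converted with `torch.tensor(...)` (using `dtype = torch.bool` for the mask) and passed to the model in the same order as in the README examples.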
  • CPU/CUDA masking error

    Hi - nice work - I was just testing out your code and ran into the following error, but only in the backward pass:

    RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mask' in call to th_masked_scatter_bool

    When using the nightly Pytorch, the error message is:

    RuntimeError: Tensor for argument #2 'mask' is on CPU, but expected it to be on GPU (while checking arguments for masked_scatter_)

    I'm pretty sure I don't have any tensors in CPU memory, but not sure if this is a bug in your SE3 code or a Pytorch issue. My gut feeling is this is a Pytorch/autograd issue, but I just don't know these particular Pytorch ops well enough to be sure. Tried both 1.7.1 release Pytorch and the latest nightly. Seems like there has been recent work on masked_scatter according to Pytorch issues.

    opened by denjots 8
  • small bug

    https://github.com/lucidrains/se3-transformer-pytorch/blob/7c79998e4d84ec6bd6b6d4b916c6bf30b870b75b/se3_transformer_pytorch/se3_transformer_pytorch.py#L301

    should be if isinstance(m,nn.Linear)

    nbd, but thought you might wanna know.

    opened by MattMcPartlon 5
  • INTERNAL ASSERT FAILED

    Hi - I'm not sure if this is actually a bug or if I'm just expecting too much, but the following code snippet bombs with a PyTorch internal assert failure:

    RuntimeError: sub_iter.strides(0)[0] == 0INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1618643394934/work/aten/src/ATen/native/cuda/Reduce.cuh":929, please report a bug to PyTorch.

    That's using the latest PyTorch nightly. Tried with both CUDA 10.2 and 11.1.

    Looking at the PyTorch bug reports, this seems to suggest that there is massive internal memory allocation being triggered when summing over a large tensor. It seems to have been sitting open for a year and I'm not sure if they even see it as an actual bug there. Certainly it doesn't seem to be a priority.

    The problem is that this isn't a particularly large data input or a large model (to say the least!), and my GPU has 40 GB of RAM, so the scalability of this particular transformer model seems very limited as it currently stands. Changing the dim from 128 to 64 does at least allow it to run.

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    model = SE3Transformer(dim = 128, heads = 1, depth = 1, dim_head = 1, num_degrees = 1, input_degrees=1, output_degrees=1).cuda()
    
    feats = torch.randn(1, 200, 128).cuda()
    coords = torch.randn(1, 200, 3).cuda()
    
    out = model(feats, coords)
    
    
    opened by denjots 4
  • question about non scalar output

    Hello,

    I am interested in using your implementation on my dataset. There are two things I want to check with you:

    1. The property I want to predict is a 3x3 symmetric PSD matrix.
    2. There are some edge features (one categorical feature, one continuous feature) besides the coordinate difference

    I was wondering whether the current se3-transformer can work with such a scenario. Thanks!

    opened by Chen-Cai-OSU 3
  • Reversible flag odd results

    Sorry if it's expected behaviour again, but with a slight tweak of your new example code I get unexpected results when your new reversible option is used to return type 1 data. Here's the code I'm running...

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    model = SE3Transformer(
        num_tokens = 20,
        dim = 32,
        dim_head = 32,
        heads = 4,
        depth = 12,             # 12 layers
        input_degrees = 1,
        num_degrees = 2,
        output_degrees = 2,
        reduce_dim_out = True,
        reversible = True       # set reversible to True
    ).cuda()
    
    atoms = torch.randint(0, 4, (1, 50)).cuda()
    coors = torch.randn(1, 50, 3).cuda()
    mask  = torch.ones(1, 50).bool().cuda()
    
    pred = model(atoms, coors, mask = mask, return_type = 1)
    
    loss = pred.mean()
    print(loss)
    loss.backward()
    
    

    Without the reversible flag, the loss is close to zero as might be expected as the input coords are centered on the origin. However, when reversible is set, a very high value is produced which is going to blow up training if this was for real. Is this an expected side effect of reversible nets? Is some kind of normalization essential in that case? Or am I just doing something wrong here?

    opened by denjots 3
  • faster loop

    My small grain of sand for this project ;) : at least avoid Python appends, which are 2x slower than list comprehensions.

    If this is of any help, here are some considerations (there might be misunderstandings on my side due to the decorators and so on):

    • I have my reservations about the usefulness of line 62 in utils.py.
    • Wouldn't it make more sense to run the for loop I modified (line 122 of utils.py) in reverse order, and then use the cached calculations for lpmv()?
    • Does the same (looping in reverse order) apply to the for loop at line 148 of basis.py?
    • If scipy.special.poch (which can handle NumPy arrays) is used instead of the custom pochhammer implementation, all operations inside get_spherical_harmonics_element are vectorizable except the lpmv function call.
      • My sense is that lpmv, get_spherical_harmonics_element and get_spherical_harmonics could all be wrapped in a single function (at the cost of lower reusability and extensibility, so maybe reversing the loop order and caching is enough).
    opened by hypnopump 1
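The two ideas raised in the list above (caching repeated evaluations and building lists with a comprehension rather than repeated appends) can be sketched on the Pochhammer symbol. This assumes the standard rising-factorial definition, which is what `scipy.special.poch` computes; whether the library's own `pochhammer` helper matches it exactly is an assumption.

```python
import math
from functools import lru_cache


@lru_cache(maxsize=None)
def pochhammer(x, k):
    """Rising factorial (x)_k = x * (x + 1) * ... * (x + k - 1), with (x)_0 = 1."""
    return math.prod(x + i for i in range(k))


# List comprehension instead of repeated list.append calls
values = [pochhammer(3, k) for k in range(5)]
print(values)  # [1, 3, 12, 60, 360]
```

Repeated calls with the same arguments are served from the cache, which is the kind of reuse the reverse-order loop suggestion is aiming for.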
  • Question about continuous edge features

    Thanks for all of the work!

    I am working on a simple proof of concept with your model. Ideally, I would like to perform multidimensional scaling, i.e., given a distance matrix, recover the corresponding coordinates.

    I was wondering if distance information could be passed as an edge feature (continuous rather than categorical information). Is it possible to do this with the current implementation?

    Thanks again and I appreciate the help!

    opened by MattMcPartlon 1
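For context, the continuous edge feature described in this question would be an n x n matrix of pairwise distances. The sketch below only shows how such a matrix is built from coordinates; the name `pairwise_distances` is hypothetical, and whether or how the library accepts continuous edge features is not confirmed by the text above.

```python
import math


def pairwise_distances(coords):
    """Return the n x n matrix of Euclidean distances between 3D points."""
    return [[math.dist(p, q) for q in coords] for p in coords]


coords = [[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 1.0]]
dists = pairwise_distances(coords)
print(dists[0][1])  # 5.0
```

The matrix is unchanged when all points are shifted or rotated together, which is why coordinates can only be recovered from it up to a rigid motion (and reflection).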
  • denoise.py bugfix

    Fixes issue related to the constructor

    Traceback (most recent call last):
      File "/home/jcastellanos/projects/se3-transformer-pytorch/denoise.py", line 22, in <module>
        transformer = SE3Transformer(
      File "/home/jcastellanos/projects/se3-transformer-pytorch/se3_transformer_pytorch/se3_transformer_pytorch.py", line 1072, in __init__
        self.num_degrees = num_degrees if exists(num_degrees) else (max(hidden_fiber_dict.keys()) + 1)
    AttributeError: 'NoneType' object has no attribute 'keys'
    
    opened by javierbq 0
  • CUDA out of memory

    Thanks for your great work!

    The se3-transformer is powerful, but it seems to be memory intensive.

    I built a model with the following parameters and got a "CUDA out of memory" error when I ran it on the GPU (Nvidia V100 / 32G).

    import torch
    from se3_transformer_pytorch import SE3Transformer

    model = SE3Transformer(dim = 20, heads = 4, depth = 2, dim_head = 5, num_degrees = 2, valid_radius = 5)

    num_points = 512
    feats = torch.randn(1, num_points, 20)
    coors = torch.randn(1, num_points, 3)
    mask = torch.ones(1, num_points).bool()
    

    Is this error related to the version of PyTorch, and how can I fix it?

    opened by PengCheng-NUDT 0
  • SE3Transformer constructor hangs

    I am trying to run an example from the README. The code is:

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    print('Initialising model...')
    model = SE3Transformer(
        dim = 512,
        heads = 8,
        depth = 6,
        dim_head = 64,
        num_degrees = 4,
        valid_radius = 10
    )
    
    print('Running model...')
    feats = torch.randn(1, 1024, 512)
    coors = torch.randn(1, 1024, 3)
    mask  = torch.ones(1, 1024).bool()
    
    out = model(feats, coors, mask) # (1, 1024, 512)
    

    The output hangs on 'Initialising model...' and eventually the kernel dies.

    Any ideas why this would be happening?

    Here is my pip freeze:

    anyio==3.2.1
    argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613036642480/work
    astunparse==1.6.3
    async-generator==1.10
    attrs @ file:///tmp/build/80754af9/attrs_1620827162558/work
    axial-positional-embedding==0.2.1
    Babel==2.9.1
    backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
    biopython==1.79
    bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work
    cached-property @ file:///tmp/build/80754af9/cached-property_1600785575025/work
    certifi==2021.5.30
    cffi @ file:///tmp/build/80754af9/cffi_1613246939562/work
    chardet==4.0.0
    click==8.0.1
    configparser==5.0.2
    decorator==4.4.2
    defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
    dgl-cu101==0.4.3.post2
    dgl-cu110==0.6.1
    docker-pycreds==0.4.0
    egnn-pytorch==0.2.6
    einops==0.3.0
    En-transformer==0.3.8
    entrypoints==0.3
    equivariant-attention @ file:///workspace/projects/se3-transformer-public
    filelock==3.0.12
    gitdb==4.0.7
    GitPython==3.1.18
    graph-transformer-pytorch==0.0.1
    h5py @ file:///tmp/build/80754af9/h5py_1622088444809/work
    huggingface-hub==0.0.12
    idna==2.10
    importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617877314848/work
    ipykernel @ file:///tmp/build/80754af9/ipykernel_1596206598566/work/dist/ipykernel-5.3.4-py3-none-any.whl
    ipython @ file:///tmp/build/80754af9/ipython_1617118429768/work
    ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
    ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work
    jedi==0.17.0
    Jinja2 @ file:///tmp/build/80754af9/jinja2_1621238361758/work
    joblib==1.0.1
    json5==0.9.6
    jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
    jupyter==1.0.0
    jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work
    jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work
    jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213308260/work
    jupyter-server==1.8.0
    jupyter-tensorboard==0.2.0
    jupyterlab==3.0.16
    jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
    jupyterlab-server==2.6.0
    jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work
    jupytext==1.11.3
    lie-learn @ git+https://github.com/AMLab-Amsterdam/[email protected]
    llvmlite==0.36.0
    local-attention==1.4.1
    markdown-it-py==1.1.0
    MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528142364/work
    mdit-py-plugins==0.2.8
    mdtraj==1.9.6
    mistune @ file:///tmp/build/80754af9/mistune_1594373098390/work
    mkl-fft==1.3.0
    mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853974840/work
    mkl-service==2.3.0
    mp-nerf==0.1.11
    nbclassic==0.3.1
    nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work
    nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914821128/work
    nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work
    nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
    networkx==2.5.1
    notebook @ file:///tmp/build/80754af9/notebook_1621523661196/work
    numba==0.53.1
    numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1620831194891/work
    packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work
    pandas==1.2.4
    pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120451932/work
    parso @ file:///tmp/build/80754af9/parso_1617223946239/work
    pathtools==0.1.2
    performer-pytorch==1.0.11
    pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
    pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
    ProDy==2.0
    prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1623189609245/work
    promise==2.3
    prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work
    protobuf==3.17.3
    psutil==5.8.0
    ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
    py3Dmol==0.9.1
    pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
    Pygments @ file:///tmp/build/80754af9/pygments_1621606182707/work
    pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work
    pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141707582/work
    python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work
    pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
    PyYAML==5.4.1
    pyzmq==20.0.0
    qtconsole @ file:///tmp/build/80754af9/qtconsole_1623278325812/work
    QtPy==1.9.0
    regex==2021.4.4
    requests==2.25.1
    sacremoses==0.0.45
    scipy @ file:///tmp/build/80754af9/scipy_1618852618548/work
    se3-transformer-pytorch==0.8.10
    Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work
    sentry-sdk==1.1.0
    shortuuid==1.0.1
    sidechainnet==0.6.0
    six @ file:///tmp/build/80754af9/six_1623709665295/work
    smmap==4.0.0
    sniffio==1.2.0
    subprocess32==3.5.4
    terminado==0.9.4
    testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work
    tokenizers==0.10.3
    toml==0.10.2
    torch==1.9.0
    tornado @ file:///tmp/build/80754af9/tornado_1606942283357/work
    tqdm==4.61.1
    traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work
    transformers==4.8.0
    typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
    urllib3==1.26.5
    wandb==0.10.32
    wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
    webencodings==0.5.1
    websocket-client==1.1.0
    widgetsnbextension==3.5.1
    zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work
    

    Here is a summary of my system info (lshw -short):

    H/W path    Device  Class      Description
    ==========================================
                        system     Computer
    /0                  bus        Motherboard
    /0/0                memory     59GiB System memory
    /0/1                processor  Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
    /0/100              bridge     440FX - 82441FX PMC [Natoma]
    /0/100/1            bridge     82371SB PIIX3 ISA [Natoma/Triton II]
    /0/100/1.1          storage    82371SB PIIX3 IDE [Natoma/Triton II]
    /0/100/1.3          bridge     82371AB/EB/MB PIIX4 ACPI
    /0/100/2            display    GD 5446
    /0/100/3            network    Elastic Network Adapter (ENA)
    /0/100/1e           display    GK210GL [Tesla K80]
    /0/100/1f           generic    Xen Platform Device
    /1          eth0    network    Ethernet interface
    
    opened by mpdprot 1
  • Whether SE3 needs pre-training

    Thank you for your work. I used your reimplementation of SE3 as part of my model, but the current test results are not very good. I suspect this is because I do not understand your model well. Here are my questions:

    1. Does your model need pre-training?
    2. Can I train the SE3 Transformer together with the fully connected layer that comes after it? Any good advice is also welcome.
    opened by zyk19981118 2
  • multiple molecules cases

    Hi,

    I normally use the dataloader from PyG to handle my molecule dataset.

    Can you please provide an example of a real multi-step training epoch?

    I ran your sample code and it works, but I need to better understand the input format as well as the output format, which in my case would be x, y, z?

    image

    thanks

    opened by thegodone 2
Releases(0.9.0)
Owner
Phil Wang
Working with Attention. It's all we need.