Skip to content

Commit

Permalink
absolute positional embedding for vision transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 21, 2023
1 parent f9447b0 commit 11dc571
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
12 changes: 12 additions & 0 deletions med_seg_diff_pytorch/med_seg_diff_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def __init__(
self,
dim,
*,
image_size,
patch_size,
channels = 3,
channels_out = None,
Expand All @@ -244,6 +245,13 @@ def __init__(
depth = 4,
):
super().__init__()
assert exists(image_size)
assert (image_size % patch_size) == 0

num_patches_height_width = image_size // patch_size

self.pos_emb = nn.Parameter(torch.zeros(dim, num_patches_height_width, num_patches_height_width))

channels_out = default(channels_out, channels)

patch_dim = channels * (patch_size ** 2)
Expand Down Expand Up @@ -272,6 +280,8 @@ def __init__(

def forward(self, x):
x = self.to_tokens(x)
x = x + self.pos_emb

x = self.transformer(x)
return self.to_patches(x)

Expand All @@ -283,6 +293,7 @@ def __init__(
fmap_size,
dim,
dynamic = True,
image_size = None,
dim_head = 32,
heads = 4,
depth = 4,
Expand Down Expand Up @@ -412,6 +423,7 @@ def __init__(
if conditioning_klass == Conditioning:
conditioning_klass = partial(
Conditioning,
image_size = image_size,
dynamic = dynamic_ff_parser_attn_map,
**conditioning_kwargs
)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'med-seg-diff-pytorch',
packages = find_packages(exclude=[]),
version = '0.3.0',
version = '0.3.1',
license='MIT',
description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 11dc571

Please sign in to comment.