diff --git a/med_seg_diff_pytorch/med_seg_diff_pytorch.py b/med_seg_diff_pytorch/med_seg_diff_pytorch.py
index 11bdfa3..343b0c7 100644
--- a/med_seg_diff_pytorch/med_seg_diff_pytorch.py
+++ b/med_seg_diff_pytorch/med_seg_diff_pytorch.py
@@ -236,6 +236,7 @@ def __init__(
         self,
         dim,
         *,
+        image_size,
         patch_size,
         channels = 3,
         channels_out = None,
@@ -244,6 +245,13 @@ def __init__(
         depth = 4,
     ):
         super().__init__()
+        assert exists(image_size)
+        assert (image_size % patch_size) == 0
+
+        num_patches_height_width = image_size // patch_size
+
+        self.pos_emb = nn.Parameter(torch.zeros(dim, num_patches_height_width, num_patches_height_width))
+
         channels_out = default(channels_out, channels)
 
         patch_dim = channels * (patch_size ** 2)
@@ -272,6 +280,8 @@ def __init__(
 
     def forward(self, x):
         x = self.to_tokens(x)
+        x = x + self.pos_emb
+
         x = self.transformer(x)
         return self.to_patches(x)
 
@@ -283,6 +293,7 @@ def __init__(
         fmap_size,
         dim,
         dynamic = True,
+        image_size = None,
         dim_head = 32,
         heads = 4,
         depth = 4,
@@ -412,6 +423,7 @@ def __init__(
         if conditioning_klass == Conditioning:
             conditioning_klass = partial(
                 Conditioning,
+                image_size = image_size,
                 dynamic = dynamic_ff_parser_attn_map,
                 **conditioning_kwargs
             )
diff --git a/setup.py b/setup.py
index dd768b5..cdff3cf 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'med-seg-diff-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.0',
+  version = '0.3.1',
   license='MIT',
   description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
   author = 'Phil Wang',