This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

TransUNet

A TransUNet image segmentation model: a U-shaped encoder-decoder whose encoder combines ResNet50V2 and a Vision Transformer (ViT).
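The sketch below traces the tensor shapes through the model for a given input size, to show how the CNN, the Transformer and the U-shaped decoder fit together. It assumes the standard R50+ViT-B_16 hybrid configuration: the CNN downsamples by 16, each feature-map position becomes one 768-dimensional token, and the decoder upsamples by 2 four times, as in the TransUNet paper. None of these constants are read from the `transunet` package, and `transunet_shapes` is a hypothetical helper.

```python
# Shape bookkeeping for a TransUNet-style model. The constants below are
# assumptions (standard R50+ViT-B_16 hybrid), not values read from `transunet`.
CNN_STRIDE = 16     # total downsampling of the ResNet50V2 part (assumed)
HIDDEN_SIZE = 768   # ViT-B token dimension (assumed)
NUM_UPSAMPLES = 4   # number of x2 upsampling steps in the decoder (assumed)


def transunet_shapes(height, width):
    """Return intermediate shapes for an input of size height x width."""
    if height % CNN_STRIDE or width % CNN_STRIDE:
        raise ValueError(f"input size must be divisible by {CNN_STRIDE}")
    grid_h, grid_w = height // CNN_STRIDE, width // CNN_STRIDE
    # One token per position of the CNN feature map (1x1 patches).
    num_tokens = grid_h * grid_w
    # Spatial size after each x2 upsampling step of the decoder.
    decoder = [(grid_h * 2 ** i, grid_w * 2 ** i) for i in range(1, NUM_UPSAMPLES + 1)]
    return {
        "cnn_feature_grid": (grid_h, grid_w),
        "tokens": (num_tokens, HIDDEN_SIZE),
        "decoder_resolutions": decoder,
        # In the TransUNet paper, the first three decoder stages are fused with
        # CNN skip features at 1/8, 1/4 and 1/2 of the input resolution.
        "skip_resolutions": decoder[:-1],
    }


for key, value in transunet_shapes(224, 224).items():
    print(key, value)
```

Under these assumptions a 224x224 input yields a 14x14 grid, i.e. 196 tokens of size 768, and the last decoder stage returns to 224x224; this is also why the input sides must be divisible by 16.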

Install

```shell
pip install --upgrade git+https://github.com/Basars/trans-unet.git
```
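The usage example loads Google's `R50+ViT-B_16.npz` checkpoint with `np.load`. Before building the model, one quick way to confirm that a downloaded copy is readable is to list the arrays it contains. The helper `summarize_checkpoint` is hypothetical and uses only NumPy; it makes no assumptions about the array names inside the checkpoint.

```python
import numpy as np


def summarize_checkpoint(path, limit=5):
    """Return the number of arrays in an .npz file and the first `limit` (name, shape) pairs.

    `path` may be a filename or an open binary file object; np.load accepts both.
    """
    # allow_pickle=True mirrors the np.load call in the usage example.
    with np.load(path, allow_pickle=True) as ckpt:
        names = sorted(ckpt.files)
        preview = [(name, ckpt[name].shape) for name in names[:limit]]
    return len(names), preview


# Example (requires the checkpoint file locally):
# count, preview = summarize_checkpoint("R50+ViT-B_16.npz")
# print(count, preview)
```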

Usage

```python
import numpy as np

from transunet import VisionTransformer

# Pretrained R50+ViT-B_16 encoder weights released by Google
weights = np.load('R50+ViT-B_16.npz', allow_pickle=True)

model = VisionTransformer(input_shape=(224, 224, 3),
                          num_classes=1,            # single output channel (binary mask)
                          w=weights,                # initialize the encoder from the checkpoint
                          encoder_trainable=False)  # keep the pretrained encoder frozen
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(...)
```
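The call to `model.fit(...)` in the example is left open. Below is a minimal sketch of preparing training arrays for it. It assumes, without confirmation from this README, that the model takes float images of shape `(N, 224, 224, 3)` and, with `num_classes=1`, is trained against one-channel binary masks of shape `(N, 224, 224, 1)`. The `[0, 1]` input scaling is also an assumption, and `prepare_batch`, `raw_images` and `raw_masks` are hypothetical names.

```python
import numpy as np

# Assumptions (not confirmed by the README): the model built with
# input_shape=(224, 224, 3) and num_classes=1 takes float images of shape
# (N, 224, 224, 3) and predicts one channel per pixel, so targets are
# (N, 224, 224, 1) with values in {0, 1}. Compare with model.output_shape.


def prepare_batch(images, masks, size=224):
    """Return float32 images scaled to [0, 1] and binary (N, H, W, 1) float32 masks."""
    # Scaling to [0, 1] is an assumption; the README does not state the
    # input range the pretrained encoder expects.
    images = np.asarray(images, dtype=np.float32) / 255.0
    masks = np.asarray(masks)
    if masks.ndim == 3:  # (N, H, W) -> (N, H, W, 1)
        masks = masks[..., np.newaxis]
    masks = (masks > 0).astype(np.float32)  # any non-zero label counts as foreground
    expected = (size, size)
    if images.shape[1:3] != expected or masks.shape[1:3] != expected:
        raise ValueError(f"expected spatial size {expected}, got "
                         f"{images.shape[1:3]} and {masks.shape[1:3]}")
    return images, masks


# Usage with the model from the example (not run here):
# x_train, y_train = prepare_batch(raw_images, raw_masks)
# model.fit(x_train, y_train, batch_size=8, epochs=10, validation_split=0.1)
```

Binarizing with `masks > 0` matches the single-channel `binary_crossentropy` setup in the example. A multi-class setup would need a different target encoding.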