The code used to extract features from a pretrained SwinUNETR tumor-segmentation model, compute each tumor's center of mass, and visualize the features with t-SNE:
from monai.networks.nets import SwinUNETR
from monai import transforms
from monai.inferers import sliding_window_inference
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
roi = 128  # sliding-window ROI size (SwinUNETR input size)
batch_size = 1  # sliding-window batch size
device = "cuda"
modalities = ["flair", "t1ce", "t1", "t2"]  # the four BraTS MRI modalities
model = SwinUNETR(
    img_size=roi,
    in_channels=4,
    out_channels=3,
    feature_size=48,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=True,
)
pretrained_pth = "/mnt/Data/SwinUNETR_BraTS_weights/fold0_f48_ep300_4gpu_dice0_8854/model.pt"
model_dict = torch.load(pretrained_pth, map_location=torch.device(device))["state_dict"]
model.load_state_dict(model_dict)
model = model.to(device)
model.eval()  # inference only; no gradient updates
test_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys="image", image_only=False),
        transforms.EnsureChannelFirstd(keys="image", channel_dim="no_channel"),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        transforms.ToTensord(keys="image"),
    ]
)
path_data = Path("/mnt/Data/SOW2/brats2021_task1/BraTS2021_Training_Data/")
# One dict per case, pointing at that case's four modality files.
im_dict_list = [
    {"image": [path_data / d.name / f"{d.name}_{m}.nii.gz" for m in modalities]}
    for d in sorted(path_data.iterdir())
    if "BraTS" in d.name
]
df = pd.DataFrame([])
for i, im_dict in enumerate(tqdm(im_dict_list)):
    x = test_transform([im_dict])
    im = x[0]["image"].to(device)
    # Run only the Swin transformer encoder; take its deepest hidden state (index 4)
    # and average over the spatial dimensions to get one feature vector per case.
    with torch.no_grad():
        out = sliding_window_inference(im, roi, batch_size, model.swinViT)[4].mean(axis=(2, 3, 4))[0]
    out = np.asarray(out.detach().to("cpu"))
    df = pd.concat([df, pd.DataFrame(out).T])
df.to_csv("/mnt/Data/example/feats_000.csv")
# Compute each tumor's center of mass from its segmentation label map.
from scipy.ndimage import center_of_mass
import nibabel as nib

def nii_loader(path):
    # Load a NIfTI volume and transpose it to (z, y, x) axis order.
    img_vol = nib.load(path)
    return img_vol.get_fdata().T
vols_dir = Path("/mnt/Data/SOW2/brats2021_task1/BraTS2021_Training_Data/")
vols_names = [f.name for f in sorted(vols_dir.iterdir()) if "BraTS" in f.name]
vols_labels_paths = [vols_dir / f / f"{f}_seg.nii.gz" for f in vols_names]
vols_centers = [center_of_mass(nii_loader(vol_path)) for vol_path in tqdm(vols_labels_paths)]
vols_centers_array = np.asarray(vols_centers)
# nii_loader transposes each volume, so center_of_mass returns (z, y, x) coordinates.
df_centers = pd.DataFrame({"Name": vols_names, "x": vols_centers_array[:, 2], "y": vols_centers_array[:, 1], "z": vols_centers_array[:, 0]})
df_centers.to_csv("/mnt/Data/example/centers.csv")
# Embed the per-case features in 2D with t-SNE and visualize them.
from sklearn.manifold import TSNE
import plotly.express as px

X_embedded = TSNE(n_components=2, learning_rate='auto', init='pca', perplexity=50).fit_transform(df)
df_embedded = pd.DataFrame(X_embedded, columns=["Feature_1", "Feature_2"])
df_embedded["Name"] = vols_names
df_embedded["position"] = df_centers["x"]
fig = px.scatter(df_embedded, x="Feature_1", y="Feature_2", hover_name="Name", color="position", width=800, height=400)
fig.show()
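If the visualization is run separately from feature extraction, the two CSVs written above can be reloaded instead of reusing the in-memory DataFrames. A minimal sketch, assuming the files are exactly as saved by the code above (both are written in the same sorted-directory order, so their rows align by position):

import pandas as pd
import plotly.express as px
from sklearn.manifold import TSNE

# Reload the per-case features and tumor centers saved above.
# index_col=0 drops the pandas index column written by to_csv.
feats = pd.read_csv("/mnt/Data/example/feats_000.csv", index_col=0)
centers = pd.read_csv("/mnt/Data/example/centers.csv", index_col=0)

emb = TSNE(n_components=2, learning_rate="auto", init="pca", perplexity=50).fit_transform(feats)
df_plot = pd.DataFrame(emb, columns=["Feature_1", "Feature_2"])
df_plot["Name"] = centers["Name"].values
df_plot["position"] = centers["x"].values  # tumor center x coordinate

fig = px.scatter(df_plot, x="Feature_1", y="Feature_2", hover_name="Name", color="position", width=800, height=400)
fig.show()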