-
-
Notifications
You must be signed in to change notification settings - Fork 243
Open
Description
import torchxrayvision as xrv
import skimage, torch, torchvision
# Run a torchxrayvision classifier on one chest X-ray and print per-pathology scores.
#
# Other inputs that were tried: "16747_3_1.jpg", "test2.png".
img = skimage.io.imread("covid-19-pneumonia-58-prior.jpg")

# Scale 8-bit pixel values into the [-1024, 1024] range used by the xrv models.
# NOTE(review): maxval=255 assumes an 8-bit image; a 16-bit file needs its own max.
img = xrv.datasets.normalize(img, 255)

# Collapse to a single channel. A plain `img.mean(2)` crashes on an image that
# is already grayscale (2-D), and for RGBA/LA files it averages the alpha channel
# into the intensities, which distorts the input the model sees.
if img.ndim == 3:
    img = img[..., :3].mean(2) if img.shape[2] >= 3 else img[..., 0]
img = img[None, ...]  # add the channel axis -> (1, H, W)

# The resize target must match the resolution of the chosen weights
# (res224 -> XRayResizer(224); e.g. "resnet50-res512-all" -> XRayResizer(512)).
transform = torchvision.transforms.Compose([
    xrv.datasets.XRayCenterCrop(),
    xrv.datasets.XRayResizer(224),
])
img = torch.from_numpy(transform(img)).float()

# Load the model. Alternatives that were tried:
#   xrv.models.ResNet(weights="resnet50-res512-all")   (needs XRayResizer(512))
#   xrv.baseline_models.jfhealthcare.DenseNet()
model = xrv.models.DenseNet(weights="densenet121-res224-all")
# Inference mode: BatchNorm must use its stored statistics rather than the
# statistics of this single-image batch.
model.eval()
with torch.no_grad():
    outputs = model(img[None, ...])  # add batch axis; or model.features(img[None, ...])

# NOTE(review): these are presumably independent multi-label scores (one per
# pathology), not a probability distribution -- they need not sum to 1, so the
# "%" below is not a confidence share. Check the torchxrayvision docs for how
# the scores of these weights should be thresholded.
scores = outputs[0].cpu().numpy()
ranked = sorted(
    # Skip unnamed outputs (presumably labels these weights were not trained on).
    ((name, score) for name, score in zip(model.pathologies, scores) if name),
    key=lambda pair: pair[1],
    reverse=True,
)
for finding, percentage in ranked:
    print(f"{finding}: {percentage * 100:.0f}%")

# Reported issue: "I'm using this code, which is pretty much the same as the
# code from the README. But the classification of the test image is completely
# wrong: the image shows pneumonia. Why?"
Metadata
Metadata
Assignees
Labels
No labels
