Environment: lime 0.2.0.1
I followed the documentation and only changed the image path.
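```python
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms

# `model`, `eval_transforms` and `get_image` are defined elsewhere in my script.

def get_pil_transform():
    transf = transforms.Compose([
        transforms.Resize((224, 224)),
        # transforms.CenterCrop(224)
    ])
    return transf

def batch_predict(images):
    # LIME calls this with a numpy array of perturbed HxWx3 images
    model.eval()
    batch = torch.stack(tuple(eval_transforms(i) for i in images), dim=0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    batch = batch.to(device)
    logits, _ = model(batch)  # my model returns a tuple; the first element is the logits
    probs = F.softmax(logits, dim=1)
    return probs.detach().cpu().numpy()
```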
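`lime_image` passes `batch_predict` a numpy array of perturbed images, so each `i` that reaches `eval_transforms` is an HxWx3 numpy array (here `uint8`, because `np.array` of an RGB PIL image is `uint8`), not a PIL image. Whatever `eval_transforms` contains therefore has to accept numpy input. As an illustration only, the sketch below shows in plain numpy what `transforms.ToTensor()` followed by `transforms.Normalize(...)` do to such an array. `preprocess_hwc_uint8` is a hypothetical helper, and the ImageNet mean/std values are an assumption to be replaced by the statistics the model was trained with.

```python
import numpy as np

# Assumed ImageNet channel statistics; replace with the model's own values.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess_hwc_uint8(image):
    """Hypothetical helper: HxWx3 uint8 array -> normalized 3xHxW float32 array.

    Mirrors transforms.ToTensor() followed by transforms.Normalize(MEAN, STD)
    for uint8 input.
    """
    image = np.asarray(image)
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"expected an HxWx3 image, got shape {image.shape}")
    x = image.astype(np.float32) / 255.0   # ToTensor scales uint8 values to [0, 1]
    x = (x - MEAN) / STD                   # per-channel normalization, broadcast over the last axis
    return np.ascontiguousarray(x.transpose(2, 0, 1))  # HWC -> CHW, as ToTensor does
```

Inside `batch_predict`, such a helper could be used as `torch.stack([torch.from_numpy(preprocess_hwc_uint8(i)) for i in images], dim=0)`.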
```python
from lime import lime_image

img = get_image('../data_finetune/sample.jpg')
pil_transform = get_pil_transform()
transformed_img = pil_transform(img)

explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(np.array(transformed_img),
                                         batch_predict,  # classification function
                                         top_labels=1,
                                         hide_color=0,
                                         num_samples=1000)  # number of images that will be sent to classification function
```
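`lime_image` calls the classification function on a numpy array of stacked perturbed images, of shape `(batch_size, H, W, 3)` with `batch_size` defaulting to 10, and expects one row of class probabilities per image. Calling `batch_predict` directly on a batch of that shape, outside the explainer, can help tell whether the failure comes from `batch_predict` itself or from LIME. A minimal sketch; `check_classifier_fn` is a hypothetical helper, and its sum-to-one check is an extra sanity assumption, not something LIME enforces.

```python
import numpy as np

def check_classifier_fn(classifier_fn, image, batch_size=10):
    """Hypothetical helper: call classifier_fn roughly the way lime_image does,
    on a numpy array of stacked copies of `image` with shape (batch_size, H, W, 3).
    Returns the shape of the predictions if every check passes."""
    image = np.asarray(image)
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"expected an HxWx3 image, got shape {image.shape}")
    # lime_image stacks perturbed copies; unperturbed copies suffice to test the interface
    batch = np.stack([image.copy() for _ in range(batch_size)])
    probs = np.asarray(classifier_fn(batch))
    if probs.ndim != 2 or probs.shape[0] != batch_size:
        raise ValueError(
            f"expected predictions of shape ({batch_size}, num_classes), got {probs.shape}")
    # Extra sanity check (assumption, not enforced by LIME): rows look like probabilities
    if not np.allclose(probs.sum(axis=1), 1.0, atol=1e-4):
        raise ValueError("prediction rows do not sum to 1")
    return probs.shape
```

For example, `check_classifier_fn(batch_predict, np.array(transformed_img))` should return `(10, num_classes)` if `batch_predict` handles LIME-style input.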
`np.array(transformed_img)` has shape (224, 224, 3) and type `numpy.ndarray`.
However, it raises the following error:



Could you help me solve this bug? Thank you so much for your kind help. I hope you have a good day.
Best regards