diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3c68e212..fedd724e 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,6 +9,7 @@ The format is based on `Keep a Changelog ` Unreleased ---------- +- Fixed bug with asserting raster legend labels (@nkorinek, #163) - Changed changelog to an rst file. (@nkorinek, #266) - Add a vignette for testing vector data plots. (@nkorinek, #208) - Add ``pillow`` as a dev requirement (@lwasser, #253) diff --git a/matplotcheck/cases.py b/matplotcheck/cases.py index c2370987..9783a698 100644 --- a/matplotcheck/cases.py +++ b/matplotcheck/cases.py @@ -769,7 +769,7 @@ def test_image_mask(self): not im_classified, "Image not expected to be classified" ) def test_legend_accuracy(self): - self.rt.assert_legend_accuracy_classified_image( + self.rt.assert_legend_labels( im_expected=im_expected, all_label_options=legend_labels ) diff --git a/matplotcheck/raster.py b/matplotcheck/raster.py index 822daea2..3748847a 100644 --- a/matplotcheck/raster.py +++ b/matplotcheck/raster.py @@ -56,111 +56,106 @@ def assert_colorbar_range(self, crange): cb[0].vmax == crange[1] ), "Colorbar maximum is not expected value:{0}".format(crange[1]) - def _which_label(self, label, all_label_options): - """Helper function for assert_legend_accuracy_classified_image - Returns string that represents a category label for label. - - Parameters - ---------- - label: string from legend to see if it contains an option in - all_label_options - all_label_options: list of lists - Each internal list represents a class and said list is a list of - strings where at least one string is expected to be in the legend - label for this category. + def get_legend_labels(self): + """Return labels from legend in a list Returns - ------ - string that is the first entry in the internal list which label is - matched with. If no match is found, return value is None + ------- + labels: List + List of labels found in the legend of a raster plot. """ - for label_opts in all_label_options: - for s in label_opts: - if s in label: - return label_opts[0] - return None - - def assert_legend_accuracy_classified_image( - self, im_expected, all_label_options - ): - """Asserts legend correctly describes classified image on Axes ax, - checking the legend labels and the values - Parameters - ---------- - im_expected: array of arrays with expected classified image on ax. - Class values must start with 0, 1, 2, etc. - all_label_options: list of lists - Each internal list represents a class and said list is a list of - strings where at least one string is expected to be in the legend - label for this category. Internal lists must be in the same order - as bins in im_expected, e.g. first internal list has the expected - label options for class 0. + # Retrieve legend + legends = self.get_legends() + # TODO add better error message -- make this a try except as + # get + # legends should return an error if no legends exist + assert legends, "No legend displayed" - Returns - ---------- - Nothing (if checks pass) or raises error + # Get each patch stored in the legends object + patches = [leg.get_patches() for leg in legends] + # Grab rgb, alpha color and associated label for each patch + # TODO: this is a nested list because patches is returned as a list + # above. could the patches object every have more than one sublist? + # TODO because we are making this power case here we need to ensure + # the expected labels list are also lower case. then we need tests + # for upper and lower case labels in expected labels and in the + # plot legend to ensure this works. current it fails if upper case + # expected labels are provided but lowercase is in the legend. + # to simplify will + label_dict = {} - Notes - ---------- - First compares all_label_options against the legend labels to find - which element of all_label_options matches that entry. E.g. if the - first legend entry has a match in the first list in all_label_options, - then that legend entry corresponds to the first class (value 0). - Then the plot image array is copied and the values are set to the - legend label that match the values (i.e. the element in - all_label_options). The same is done for the expected image array. - Finally those two arrays of strings are compared. Passes if they match. - """ - # Retrieve image array - im_data = [] - if self.ax.get_images(): - im = self.ax.get_images()[0] - im_data, im_cmap = im.get_array(), im.get_cmap() - assert list(im_data), "No Image Displayed" + # Iterate through each patch (legend box) and grab label and facecolor + for a_patch in patches[0]: + label = a_patch.get_label().lower() + label_dict[label] = {"color": a_patch.get_facecolor()} - # Retrieve legend - legends = self.get_legends() - assert legends, "No legend displayed" + return label_dict + + def _check_label(self, labels, expected_labels): + """Helper function for assert_legend_labels + Tests each label in the legend to see if the text in expected labels + matches the text found in the legend labels. - # Retrieve legend entries and find which element of all_label_options - # matches that entry - legend_dict = {} - for p in [ - p - for sublist in [leg.get_patches() for leg in legends] - for p in sublist - ]: - label = p.get_label().lower() - legend_dict[p.get_facecolor()] = self._which_label( - label, all_label_options + Parameters + ---------- + # TODO: update all parameters and associated parameter description + # input -- dictionary now for labels object + labels: string from legend to see if it contains an option in + expected_labels + expected_labels: list of lists + Each list within the main list should contain a list of strings + that are expected to be found in each label in the plot + legend that is being tested. + TODO: clarify if this is or or "and" - ie i think it's or - is + just makes sure that one of the words in the sublist of expected + labels is in the plot legend + + Returns + ------ + Dictionary ... #TODO update this return statement + string that is the first entry in the internal list which label is + matched with. If no match is found, return value is None + """ + # TODO: return boolean instead of a none value - true if it matches, + # false if it does not match + + # + # for label_option in expected_labels: + # if label_option == label: + # return label_option + + label_check = labels.copy() + + # Iteratively test each label found in the plot legend to see if it is + # in the list of expected labels + # Implementing dictionaries here! + for i, a_label in enumerate(labels.keys()): + # print(a_label) + # for expected_label in expected_labels[i]: + # test = a_label in expected_label + # print(a_label, expected_label) + # print(test) + + label_check[a_label]["match"] = any( + a_label in expected_label + for expected_label in expected_labels[i] ) - # Check that each legend entry label is in one of all_label_options - assert len([val for val in legend_dict.values() if val]) == len( - all_label_options - ), "Incorrect legend labels" + # test2 = [a_label in expected_label + # for expected_label in expected_labels[i]] + # any(test2) + # print(expected_label) - # Create two copies of image array, one filled with the plot data class - # labels (im_data_labels) and the other with the expected labels - # (im_expected_labels) - im_class_dict = {} - for val in np.unique(im_data): - im_class_dict[val] = legend_dict[im_cmap(im.norm(val))] - im_data_labels = [ - [im_class_dict[val] for val in row] for row in im_data.data - ] - im_expected_labels = [ - [all_label_options[val][0] for val in row] for row in im_expected - ] - - # Check that expected and actual labels match up - assert np.array_equal( - im_data_labels, im_expected_labels - ), "Incorrect legend to data relation" + # for i, label in enumerate(labels): + # test_output = any( + # label in expected_label + # for expected_label in expected_labels[i] + # ) + # label_check[label] = test_output - # IMAGE TESTS/HELPER FUNCTIONS + return label_check def get_plot_image(self): """Returns images stored on the Axes object as a list of numpy arrays. @@ -172,13 +167,149 @@ def get_plot_image(self): """ im_data = [] if self.ax.get_images(): - im_data = self.ax.get_images()[0].get_array() + im = self.ax.get_images()[0] + im_data = im.get_array() + im_cmap = im.get_cmap() + + # TODO make this a better test (Try / except??) / return more + # expressive error assert list(im_data), "No Image Displayed" # If image array has 3 dims (e.g. rgb image), remove alpha channel if len(im_data.shape) == 3: im_data = im_data[:, :, :3] - return im_data + + return (im_data, im_cmap) + + def assert_raster_legend_labels(self, im_expected, expected_labels): + """Asserts legend correctly describes classified image on Axes ax, + checking the legend labels and the values + + Parameters + ---------- + im_expected: array of arrays with expected classified image on ax. + expected_labels: list of lists + Each sublist within the expected_labels list contains the word + or word variations expected to be found in the legend labels of + the plot being tested. Example list: [["gain", "increase"]] + would be provided if you wanted to test that the word "gain" OR + "increase" were found in the first legend element. + TODO: i think it's or but let's just clarify it's not "and" + TODO: we should have tests that check what happens if someone + provides only 2 sublist but there are three legend labels. + Sublists must be in the same order as the legend elements are + in. EXAMPLE: the first sublist will map to the first labeled item + in a plot legend. + + Returns + ---------- + Nothing (if checks pass) or raises assertion error + """ + + # TODO add test for a plot with no image in it. get_plot_image should + # return an error + + # Retrieve image array + im_data, im_cmap = self.get_plot_image() + + # TODO: We shouldn't need these tests because they happen in + # get_plot_image already. But we should improve the output message in + # get_plot image to be something more expressive + # assert list(im_data), "No Image Displayed" + + labels = self.get_legend_labels() + + # TODO: i think this should be a try, catch / return value error + # this still works as a dictionary as there is 3 keys + assert len(labels) == len(expected_labels), ( + "Number of label options provided doesn't match the number of" + " labels found in the image." + ) + + # TODO: this currently only returns a list of values. It would be + # better if it returns a dictionary with the key being each + # label being tested and the value being a boolean (True if there + # is a match, False if there is no match) + # TODO: UPDATE CKECK LABEL to take input dictionary rather than list + labels_dict = self._check_label(labels, expected_labels) + + # labels_check = [ + # self._check_label(label, expected_labels[i]) + # for i, label in enumerate(labels) + # ] + + # TODO: fix the dict comprehension below to grab the correct key for + # true / false + + # Pull out any labels that failed the above test for final printing + # below + bad_labels = { + key: labels_dict[key] + for key in labels_dict + if not labels_dict[key]["match"] + } + + # TODO: raise assertion error (value error?) and print out a list of + # labels that are wrong ONLY if some are wrong + if bad_labels: + # get just the labels that are + bad_keys = [str(a_key) for a_key in bad_labels.keys()] + raise ValueError( + "Oops. It looks like atleast one of your legend " + "labels is incorrect. Double check the " + "following label(s): {" + "}".format(bad_keys) + ) + + # Check that each legend entry label is in one of expected_labels + # assert all( + # labels_check + # ), "Provided legend labels don't match labels found." + + # TODO: once we get the above working, let's then add another layer + # where we grab the RGB values and also add that to the dictionary + # in the above we have the color and the label. now we need another + # dictionary that has the array value and map that to color. + + # At that point we can test whether the colors in the plot array, map + # to the legend patch colors and in turn the expected image + + # Get image -- this can be a helper... + # TODO: remember how cmaps map to data in a np array. i believe we + # can pull from the earthpy legend function to help with this. + # We will need to know whether the vmin and vmax are modified in the + # plot i think as well... this could get tricky... + + # BEGIN WIP + # Essential what this should do is grab the colors used in the plot + # for each unique array value. NOTE that we may have to consider both + # continuous and non continuous data here (so arrays with 123, + # 012 or 0,4,7 as examples) We will need tests for all cases. + + # im_class_dict = {} + for val in np.unique(im_data): + print(val) + # We may need to handle a list different from an existin gcmap + # cmap_type = im_cmap.name + # im_class_dict[val] = legend_dict[im_cmap(im.norm(val))] + + # im_data_labels = [ + # [im_class_dict[val] for val in row] for row in im_data.data + # ] + # im_expected_labels = [ + # [all_label_options[val][0] for val in row] for row in im_expected + # ] + # END WIP + + # Check that expected and actual arrays data match up + assert np.array_equal( + im_data, im_expected + ), "Expected image data doesn't match data in image." + + # TODO: warning -- proj_create: init=epsg:/init=IGNF: syntax not + # supported in non-PROJ4 emulation mode - where is this coming from? + + # IMAGE TESTS/HELPER FUNCTIONS def assert_image( self, im_expected, im_classified=False, m="Incorrect Image Displayed" @@ -200,6 +331,9 @@ def assert_image( ---------- Nothing (if checks pass) or raises error """ + # TODO this should be able to call the get_image helper above rather + # than recreating get image. + im_data = [] if self.ax.get_images(): im_data = self.ax.get_images()[0].get_array() diff --git a/matplotcheck/tests/test_raster.py b/matplotcheck/tests/test_raster.py index 7c4d7fd8..33fcd9cd 100644 --- a/matplotcheck/tests/test_raster.py +++ b/matplotcheck/tests/test_raster.py @@ -157,15 +157,23 @@ def test_raster_assert_colorbar_range_blank(raster_plt_blank, np_ar): """ LEGEND TESTS """ +def test_get_legend_labels_accuracy(raster_plt_class, np_ar_discrete): + """Checks that helper function get_legend_labels returns the right labels. + """ + values = np.sort(np.unique(np_ar_discrete)) + label_options = ["level " + str(i) for i in values] + + assert label_options == raster_plt_class.get_legend_labels() + plt.close() + + def test_raster_assert_legend_accuracy(raster_plt_class, np_ar_discrete): """Checks that legend matches image, checking both the labels and color patches""" values = np.sort(np.unique(np_ar_discrete)) - label_options = [[str(i)] for i in values] + label_options = [["level " + str(i)] for i in values] - raster_plt_class.assert_legend_accuracy_classified_image( - np_ar_discrete, label_options - ) + raster_plt_class.assert_raster_legend_labels(np_ar_discrete, label_options) plt.close() @@ -177,8 +185,8 @@ def test_raster_assert_legend_accuracy_badlabel( # Should fail with bad label bad_label_options = [["foo"] * values.shape[0]] - with pytest.raises(AssertionError, match="Incorrect legend labels"): - raster_plt_class.assert_legend_accuracy_classified_image( + with pytest.raises(AssertionError, match="Number of label options provid"): + raster_plt_class.assert_raster_legend_labels( np_ar_discrete, bad_label_options ) plt.close() @@ -198,11 +206,9 @@ def test_raster_assert_legend_accuracy_badvalues( # Should fail with bad image with pytest.raises( - AssertionError, match="Incorrect legend to data relation" + AssertionError, match="Provided legend labels don't match labels found" ): - raster_plt_class.assert_legend_accuracy_classified_image( - bad_image, label_options - ) + raster_plt_class.assert_raster_legend_labels(bad_image, label_options) plt.close() @@ -213,9 +219,7 @@ def test_raster_assert_legend_accuracy_nolegend(raster_plt, np_ar_discrete): # Fails without legend with pytest.raises(AssertionError, match="No legend displayed"): - raster_plt.assert_legend_accuracy_classified_image( - np_ar_discrete, label_options - ) + raster_plt.assert_raster_legend_labels(np_ar_discrete, label_options) plt.close() @@ -228,7 +232,7 @@ def test_raster_assert_legend_accuracy_noimage( # Fails when no image displayed with pytest.raises(AssertionError, match="No Image Displayed"): - raster_plt_blank.assert_legend_accuracy_classified_image( + raster_plt_blank.assert_raster_legend_labels( np_ar_discrete, label_options ) plt.close()