Add 3D support and confusion matrix output to PanopticQualityMetric #8684
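This PR adds two features to `PanopticQualityMetric`: acceptance of 3D (volumetric) inputs and an optional raw confusion-matrix output. A minimal sketch of how the new 3D path would be exercised once the PR is applied; the tensors are the 3D test volumes from the diff below, with layout batch x 2 channels (instance ids, class ids) x D x H x W:

import torch
from monai.metrics import PanopticQualityMetric

# the 3D test volumes from this PR: shape (1, 2, 2, 2, 2)
y_pred = torch.as_tensor([[[[[2, 0], [1, 1]], [[0, 1], [2, 1]]], [[[0, 1], [3, 0]], [[1, 0], [1, 1]]]]])
y_gt = torch.as_tensor([[[[[2, 0], [0, 0]], [[2, 2], [2, 3]]], [[[3, 3], [3, 2]], [[2, 2], [3, 3]]]]])

metric = PanopticQualityMetric(num_classes=3, match_iou_threshold=0.5)
metric(y_pred, y_gt)       # previously only 4-dim (2D image) inputs were accepted
print(metric.aggregate())  # per-class panoptic quality, shape (3,)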
@@ -88,6 +88,27 @@
    [torch.as_tensor([[0.0, 1.0, 0.0], [0.6667, 0.0, 0.4]]), torch.as_tensor([[0.0, 0.5, 0.0], [0.3333, 0.0, 0.4]])],
]

# 3D test cases
sample_3d_pred = torch.as_tensor(
    [[[[[2, 0], [1, 1]], [[0, 1], [2, 1]]], [[[0, 1], [3, 0]], [[1, 0], [1, 1]]]]],  # instance channel # class channel
    device=_device,
)

sample_3d_gt = torch.as_tensor(
    [[[[[2, 0], [0, 0]], [[2, 2], [2, 3]]], [[[3, 3], [3, 2]], [[2, 2], [3, 3]]]]],  # instance channel # class channel
    device=_device,
)

# test 3D sample, num_classes = 3, match_iou_threshold = 0.5
TEST_3D_CASE_1 = [{"num_classes": 3, "match_iou_threshold": 0.5}, sample_3d_pred, sample_3d_gt]

# test confusion matrix return
TEST_CM_CASE_1 = [
    {"num_classes": 3, "match_iou_threshold": 0.5, "return_confusion_matrix": True},
    sample_3_pred,
    sample_3_gt,
]


@SkipIfNoModule("scipy.optimize")
class TestPanopticQualityMetric(unittest.TestCase):
@@ -108,6 +129,98 @@ def test_value_class(self, input_params, y_pred, y_gt, expected_value):
        else:
            np.testing.assert_allclose(outputs.cpu().numpy(), np.asarray(expected_value), atol=1e-4)

    def test_3d_support(self):
        """Test that 3D input is properly supported."""
        input_params, y_pred, y_gt = TEST_3D_CASE_1
        metric = PanopticQualityMetric(**input_params)
        # Should not raise an error for 3D input
        metric(y_pred, y_gt)
        outputs = metric.aggregate()
        # Check that output is a tensor
        self.assertIsInstance(outputs, torch.Tensor)
        # Check that output shape is correct (num_classes,)
        self.assertEqual(outputs.shape, torch.Size([3]))

    def test_confusion_matrix_return(self):
        """Test that the confusion matrix can be returned instead of the computed metrics."""
        input_params, y_pred, y_gt = TEST_CM_CASE_1
        metric = PanopticQualityMetric(**input_params)
        metric(y_pred, y_gt)
        outputs = metric.aggregate()
        # Check that output is a tensor with shape (batch_size, num_classes, 4)
        self.assertIsInstance(outputs, torch.Tensor)
        self.assertEqual(outputs.shape[-1], 4)
        # Verify that values correspond to [tp, fp, fn, iou_sum]
        tp, fp, fn, iou_sum = outputs[..., 0], outputs[..., 1], outputs[..., 2], outputs[..., 3]
        # tp, fp, fn should be non-negative integers
        self.assertTrue(torch.all(tp >= 0))
        self.assertTrue(torch.all(fp >= 0))
        self.assertTrue(torch.all(fn >= 0))
        # iou_sum should be a non-negative float
        self.assertTrue(torch.all(iou_sum >= 0))

    def test_compute_mean_iou(self):
        """Test mean IoU computation from the confusion matrix."""
        from monai.metrics.panoptic_quality import compute_mean_iou
Member: Could this be imported at the top of the file as normal?
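For illustration only (not part of the diff), the suggestion would amount to replacing the in-function import with a module-level one next to the file's other imports:

# at the top of the test module, alongside the existing imports
from monai.metrics.panoptic_quality import compute_mean_iou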
        input_params, y_pred, y_gt = TEST_CM_CASE_1
        metric = PanopticQualityMetric(**input_params)
        metric(y_pred, y_gt)
        confusion_matrix = metric.aggregate()
        mean_iou = compute_mean_iou(confusion_matrix)
        # Check shape is correct
        self.assertEqual(mean_iou.shape, confusion_matrix.shape[:-1])
        # Check values are non-negative
        self.assertTrue(torch.all(mean_iou >= 0))
Member: I think this test should compare the output against a known ground truth so that the actual computed value is tested to be correct.
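One way to do that without hard-coding numbers is to feed an identical prediction and ground truth, where every segment must match with IoU 1.0. A rough sketch of such a test (hypothetical test name, reusing sample_3_gt from the existing cases, and assuming compute_mean_iou normalizes the accumulated IoU sums by the matched counts):

    def test_confusion_matrix_known_values(self):
        """With an identical prediction and ground truth, every segment matches with IoU 1.0."""
        from monai.metrics.panoptic_quality import compute_mean_iou

        metric = PanopticQualityMetric(num_classes=3, match_iou_threshold=0.5, return_confusion_matrix=True)
        metric(sample_3_gt, sample_3_gt)  # prediction identical to ground truth
        cm = metric.aggregate()
        tp, fp, fn, iou_sum = cm[..., 0], cm[..., 1], cm[..., 2], cm[..., 3]
        # no spurious or missed segments, and each match contributes an IoU of exactly 1.0
        self.assertTrue(torch.all(fp == 0))
        self.assertTrue(torch.all(fn == 0))
        np.testing.assert_allclose(iou_sum.cpu().numpy(), tp.cpu().numpy(), atol=1e-4)
        # classes with at least one matched segment should then report a mean IoU of 1.0
        mean_iou = compute_mean_iou(cm)
        present = tp > 0
        np.testing.assert_allclose(mean_iou[present].cpu().numpy(), 1.0, atol=1e-4)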
    def test_metric_name_filtering(self):
        """Test that the metric_name parameter properly filters the output."""
        # Test single metric "sq"
        metric_sq = PanopticQualityMetric(num_classes=3, metric_name="sq", match_iou_threshold=0.5)
        metric_sq(sample_3_pred, sample_3_gt)
        result_sq = metric_sq.aggregate()
        self.assertIsInstance(result_sq, torch.Tensor)
        self.assertEqual(result_sq.shape, torch.Size([3]))

        # Test single metric "rq"
        metric_rq = PanopticQualityMetric(num_classes=3, metric_name="rq", match_iou_threshold=0.5)
        metric_rq(sample_3_pred, sample_3_gt)
        result_rq = metric_rq.aggregate()
        self.assertIsInstance(result_rq, torch.Tensor)
        self.assertEqual(result_rq.shape, torch.Size([3]))

        # Results should be different for different metrics
        self.assertFalse(torch.allclose(result_sq, result_rq, atol=1e-4))

    def test_invalid_3d_shape(self):
        """Test that invalid input shapes are rejected."""
        # Shape with 3 dimensions should fail
        invalid_pred = torch.randint(0, 5, (2, 2, 10))
        invalid_gt = torch.randint(0, 5, (2, 2, 10))
        metric = PanopticQualityMetric(num_classes=3)
        with self.assertRaises(ValueError):
            metric(invalid_pred, invalid_gt)

        # Shape with 6 dimensions should fail
        invalid_pred = torch.randint(0, 5, (1, 2, 8, 8, 8, 8))
        invalid_gt = torch.randint(0, 5, (1, 2, 8, 8, 8, 8))
        with self.assertRaises(ValueError):
            metric(invalid_pred, invalid_gt)

    def test_compute_mean_iou_invalid_shape(self):
        """Test that compute_mean_iou raises ValueError for invalid shapes."""
        from monai.metrics.panoptic_quality import compute_mean_iou

        # Shape (..., 3) instead of (..., 4) should fail
        invalid_confusion_matrix = torch.zeros(3, 3)
        with self.assertRaises(ValueError):
            compute_mean_iou(invalid_confusion_matrix)

        # Shape (..., 5) should also fail
        invalid_confusion_matrix = torch.zeros(2, 5)
        with self.assertRaises(ValueError):
            compute_mean_iou(invalid_confusion_matrix)


if __name__ == "__main__":
    unittest.main()
I don't see where this function is being used; was it added just as an added utility?