diff --git a/bindsnet/evaluation/evaluation.py b/bindsnet/evaluation/evaluation.py index cd4fe2a6..f2f44dfc 100644 --- a/bindsnet/evaluation/evaluation.py +++ b/bindsnet/evaluation/evaluation.py @@ -27,6 +27,7 @@ def assign_labels( :return: Tuple of class assignments, per-class spike proportions, and per-class firing rates. """ + n_neurons = spikes.size(2) if rates is None: @@ -34,29 +35,28 @@ def assign_labels( # Sum over time dimension (spike ordering doesn't matter). spikes = spikes.sum(1) - + for i in range(n_labels): + # Create mask (faster and allows future steps to stay on GPU). + mask = (labels == i) # Count the number of samples with this label. - n_labeled = torch.sum(labels == i).float() + n_labeled = mask.sum().float() if n_labeled > 0: - # Get indices of samples with this label. - indices = torch.nonzero(labels == i).view(-1) - - # Compute average firing rates for this label. - selected_spikes = torch.index_select( - spikes, dim=0, index=torch.tensor(indices) - ) - rates[:, i] = alpha * rates[:, i] + ( - torch.sum(selected_spikes, 0) / n_labeled - ) + # Get indices of samples with this label (masking is faster and stays on the GPU). + label_sum = spikes[mask].sum(0) + # Update rates. + rates[:, i] = alpha * rates[:, i] + (label_sum / n_labeled) - # Compute proportions of spike activity per class. - proportions = rates / rates.sum(1, keepdim=True) - proportions[proportions != proportions] = 0 # Set NaNs to 0 + # Compute proportions (and use 'torch.where' to avoid NaN bug). + total_activity = rates.sum(1, keepdim=True) + proportions = torch.where(total_activity > 0, rates / total_activity, torch.zeros_like(rates)) # Neuron assignments are the labels they fire most for. - assignments = torch.max(proportions, 1)[1] + max_vals, assignments = torch.max(proportions, 1) + + # Set unassigned (silent) neurons to -1 instead of defaulting to 0. + assignments[max_vals == 0] = -1 return assignments, proportions, rates