|
11 | 11 | # adapted from https://github.com/HobbitLong/SupContrast |
12 | 12 | # modified for multi-supcon |
13 | 13 | class MultiSupConLoss(GenericPairLoss): |
| 14 | + """ |
| 15 | + Args: |
| 16 | + num_classes: number of classes |
| 17 | + temperature: temperature for scaling the similarity matrix |
| 18 | + threshold: threshold for jaccard similarity |
| 19 | + |
| 20 | + Inputs: |
| 21 | + embeddings: tensor of size (batch_size, embedding_size) |
| 22 | + labels: tensor of size (batch_size, num_classes) |
| 23 | + each row is a binary vector of size num_classes that only has 1s for the positive |
| 24 | + labels, and 0s for the negative labels |
| 25 | + indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix) |
| 26 | + or size 5 for pairs (anchor1, postives, anchor2, negativesm, jaccard_matrix) |
| 27 | + Can also be left as None |
| 28 | + ref_emb: tensor of size (batch_size, embedding_size) |
| 29 | + """ |
14 | 30 | def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs): |
15 | 31 | super().__init__(mat_based_loss=True, **kwargs) |
16 | 32 | self.temperature = temperature |
@@ -77,10 +93,13 @@ def forward( |
77 | 93 | """ |
78 | 94 | Args: |
79 | 95 | embeddings: tensor of size (batch_size, embedding_size) |
80 | | - labels: tensor of size (batch_size) |
81 | | - indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives) |
82 | | - or size 4 for pairs (anchor1, postives, anchor2, negatives) |
| 96 | + labels: tensor of size (batch_size, num_classes) |
| 97 | + each row is a binary vector of size num_classes that only has 1s for the positive |
| 98 | + labels, and 0s for the negative labels |
| 99 | + indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix) |
| 100 | + or size 5 for pairs (anchor1, postives, anchor2, negativesm, jaccard_matrix) |
83 | 101 | Can also be left as None |
| 102 | + ref_emb: tensor of size (batch_size, embedding_size) |
84 | 103 | Returns: the loss |
85 | 104 | """ |
86 | 105 | self.reset_stats() |
|
0 commit comments