Skip to content

Commit 2cb2ad0

Browse files
committed
Add safe-guard for masked version
1 parent d80ed92 commit 2cb2ad0

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

include/graphblas/reference/blas3.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,15 +1327,19 @@ namespace grb {
13271327
for( auto i = start_row; i < end_row; ++i ) {
13281328
auto mask_k = mask_raw.col_start[ i ];
13291329
for( auto k = A_crs_raw.col_start[ i ]; k < A_crs_raw.col_start[ i + 1 ]; ++k ) {
1330-
auto k_col = A_crs_raw.row_index[ k ];
1330+
const auto j = A_crs_raw.row_index[ k ];
13311331

13321332
// Increment the mask pointer until we find the right column, or a lower column (since the storage withing a row is sorted in a descending order)
1333-
while( mask_k < mask_raw.col_start[ i + 1 ] && mask_raw.row_index[ mask_k ] > k_col ) {
1333+
while( mask_k < mask_raw.col_start[ i + 1 ] && mask_raw.row_index[ mask_k ] > j ) {
13341334
_DEBUG_THREADESAFE_PRINT( "NEquals masked coordinate: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
13351335
mask_k++;
13361336
}
1337+
if( mask_k >= mask_raw.col_start[ i + 1 ] ) {
1338+
_DEBUG_THREADESAFE_PRINT( "No value left for this column\n" );
1339+
break;
1340+
}
13371341

1338-
if( mask_raw.row_index[ mask_k ] < k_col || not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
1342+
if( mask_raw.row_index[ mask_k ] < j || not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
13391343
mask_k++;
13401344
_DEBUG_THREADESAFE_PRINT( "Skip masked value at: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
13411345
continue;
@@ -1344,7 +1348,7 @@ namespace grb {
13441348
_DEBUG_THREADESAFE_PRINT( "Found masked value at: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
13451349
// Get A value
13461350
const auto a_val_before = A_crs_raw.values[ k ];
1347-
_DEBUG_THREADESAFE_PRINT( "A( " + std::to_string( i ) + ";" + std::to_string( k_col ) + " ) = " + std::to_string( a_val_before ) + "\n" );
1351+
_DEBUG_THREADESAFE_PRINT( "A( " + std::to_string( i ) + ";" + std::to_string( j ) + " ) = " + std::to_string( a_val_before ) + "\n" );
13481352
// Compute the fold for this coordinate
13491353
local_rc = local_rc ? local_rc : grb::apply< descr >( A_crs_raw.values[ k ], a_val_before, x, op );
13501354
local_rc = local_rc ? local_rc : grb::apply< descr >( A_ccs_raw.values[ k ], a_val_before, x, op );

0 commit comments

Comments
 (0)