Skip to content

Commit 0868fc3

Browse files
authored
Merge pull request #66 from jacobmerson/feature-scalarMatMult
Create ScalarMatMultiply
2 parents 8a0cbd8 + 06d5642 commit 0868fc3

File tree

14 files changed

+520
-50
lines changed

14 files changed

+520
-50
lines changed

core/lasCSRCore.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,34 @@ namespace las
4040
}
4141
}
4242
};
43+
class IdentityCSR : public CSRBuilder
44+
{
45+
protected:
46+
int ndofs;
47+
public:
48+
IdentityCSR(int ndofs)
49+
: CSRBuilder(ndofs,ndofs)
50+
, ndofs(ndofs)
51+
{
52+
}
53+
void run()
54+
{
55+
for(int i=0; i<ndofs;++i)
56+
{
57+
add(i,i);
58+
}
59+
}
60+
};
4361
Sparsity * createCSR(apf::Numbering * num, int ndofs)
4462
{
4563
CSRFromNumbering bldr(num,ndofs);
4664
bldr.run();
4765
return bldr.finalize();
4866
}
67+
Sparsity * createIdentityCSR(int ndofs)
68+
{
69+
IdentityCSR bldr(ndofs);
70+
bldr.run();
71+
return bldr.finalize();
72+
}
4973
}

core/lasCSRCore.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ namespace las
1010
* values of the numbering, no mesh partitioning is considered.
1111
*/
1212
Sparsity * createCSR(apf::Numbering * num, int ndofs);
13+
Sparsity * createIdentityCSR(int ndofs);
1314
}
1415
#endif

src/las.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ namespace las
7171
{
7272
static_cast<T*>(this)->_set(v,cnt,rws,vls);
7373
}
74+
void set(Vec * v, scalar * vls)
75+
{
76+
static_cast<T*>(this)->_set(v,vls);
77+
}
7478
void set(Mat * m, int cntr, int * rws, int cntc, int * cls, scalar * vls)
7579
{
7680
static_cast<T*>(this)->_set(m,cntr,rws,cntc,cls,vls);
@@ -198,6 +202,8 @@ namespace las
198202
virtual void solve(Mat * k, Vec * u, Vec * f) = 0;
199203
virtual ~Solve() {}
200204
};
205+
template <typename T>
206+
Solve * getSolve(int id);
201207
/**
202208
* Interface for Matrix-Vector multiplication
203209
* @todo Retrieve backend-specific solvers using
@@ -210,6 +216,8 @@ namespace las
210216
virtual void exec(Mat * x, Vec * a, Vec * b) = 0;
211217
virtual ~MatVecMult() {}
212218
};
219+
template <class T>
220+
MatVecMult * getMatVecMult();
213221
/**
214222
* Interface for Matrix-Matrix multiplication
215223
* @todo Retrieve backend-specific solvers using
@@ -222,6 +230,56 @@ namespace las
222230
virtual void exec(Mat * a, Mat * b, Mat ** c) = 0;
223231
virtual ~MatMatMult() {}
224232
};
233+
template <class T>
234+
MatMatMult * getMatMatMult();
235+
/**
236+
* Interface for Scalar-Matrix multiplication
237+
* If c is NULL performs an in place multiplication
238+
* @todo Retrieve backend-specific solvers using
239+
* backend id classes to do template
240+
* specialization, as above.
241+
*/
242+
class ScalarMatMult
243+
{
244+
public:
245+
virtual void exec(scalar s, Mat * a, Mat ** c) = 0;
246+
virtual ~ScalarMatMult() {}
247+
};
248+
template <class T>
249+
ScalarMatMult * getScalarMatMult();
250+
/*
251+
* interface for C = alpha_1*A+alpha_2*B
252+
*/
253+
class MatMatAdd
254+
{
255+
public:
256+
virtual void exec(scalar s1, Mat * a, scalar s2, Mat * b, Mat ** c) = 0;
257+
virtual ~MatMatAdd() {}
258+
};
259+
template <class T>
260+
MatMatAdd * getMatMatAdd();
261+
/*
262+
* interface for vector vector addition
263+
*/
264+
class VecVecAdd
265+
{
266+
public:
267+
virtual void exec(scalar s1, Vec * v1, scalar s2, Vec * v2, Vec *& v3) = 0;
268+
virtual ~VecVecAdd() {}
269+
};
270+
template <class T>
271+
VecVecAdd * getVecVecAdd();
272+
/*
273+
* Interface for scalar-vector multiplication
274+
*/
275+
class ScalarVecMult
276+
{
277+
public:
278+
virtual void exec(scalar s, Vec * x, Vec ** y) = 0;
279+
virtual ~ScalarVecMult() {}
280+
};
281+
template <class T>
282+
ScalarVecMult * getScalarVecMult();
225283
/*
226284
* Finalize routines which must be called on a matrix when switching from
227285
* add mode to set mode

src/lasCSR.h

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include "lasSys.h"
44
#include <vector>
55
#include <iostream>
6+
#include <algorithm>
7+
#include <cassert>
68
namespace las
79
{
810
class Sparsity;
@@ -35,21 +37,29 @@ namespace las
3537
* they are converted to use 1-indexing (in a debug build this generates a warning).
3638
*/
3739
CSR(int r, int c, int nnz, int * rs, int * cs);
40+
CSR(int r, int c, int nnz, std::vector<int> const & rs, std::vector<int> const & cs);
3841
int getNumRows() const { return nr; }
3942
int getNumCols() const { return nc; }
4043
int getNumNonzero() const { return nnz; }
44+
// return the index into the values array
45+
// if the location is not stored then return -1
46+
// note rw and cl start at zero
4147
int operator()(int rw, int cl) const
4248
{
43-
int result = -1;
44-
int fst = rws[rw] - 1;
45-
while((fst < rws[rw+1] - 2) && (cls[fst] - 1 < cl))
46-
++fst;
47-
// the column is correct at offset and the row isn't empty
48-
if(cls[fst] - 1 == cl && rws[rw] - 1 <= rws[rw+1] - 2)
49-
result = fst;
50-
else
51-
result = -1;
52-
return result;
49+
assert(rw < nr && rw>=0);
50+
assert(cl < nc && cl>=0);
51+
// the row is empty
52+
if(rws[rw+1]-rws[rw] == 0)
53+
return -1;
54+
// this approach finds the correct index in log(n) time where
55+
// n is the number of elements on the row
56+
typedef std::vector<int>::const_iterator vit_t;
57+
vit_t bgn = cls.begin()+rws[rw]-1;
58+
vit_t end = cls.begin()+rws[rw+1]-1;
59+
std::pair<vit_t, vit_t> bounds = equal_range(bgn, end, cl+1);
60+
if(bounds.first == bounds.second)
61+
return -1;
62+
return bounds.first-cls.begin();
5363
}
5464
int * getRows() { return &rws[0]; }
5565
int * getCols() { return &cls[0]; }

src/lasCSRBuilder_impl.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ namespace las
2929
auto it = std::unique(coords.begin(), coords.end(), unique_comps<int,int>);
3030
coords.resize(std::distance(coords.begin(), it));
3131
assert(nnz >= coords.size());
32-
if(coords.size() < nnz) {
33-
std::cerr<<"Warning: ignored "<<nnz-coords.size() << " duplicate entries\n";
34-
}
32+
// this is useful for debugging, but doesn't make sense when the builder is needed
33+
// for things like the add function...
34+
//if(coords.size() < nnz) {
35+
// std::cerr<<"Warning: ignored "<<nnz-coords.size() << " duplicate entries\n";
36+
//}
3537
nnz = coords.size();
3638
cls.resize(nnz);
3739
for(std::size_t i=0; i<coords.size(); ++i) {
@@ -61,7 +63,7 @@ namespace las
6163
result = false;
6264
if(rw>=rw_cnt || cl>=cl_cnt)
6365
{
64-
std::cerr<<"Warning: inserting a row outside the matrix bounds. Skipping it.\n";
66+
std::cerr<<"Warning: inserting a row/column outside the matrix bounds ("<<rw<<","<<cl<<"). Skipping it.\n";
6567
result=false;
6668
}
6769
return result;

src/lasCSR_impl.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ namespace las
2323
cls[cl]++;
2424
}
2525
}
26+
LAS_INLINE CSR::CSR(int r, int c, int nnz, std::vector<int> const & rs, std::vector<int> const & cs)
27+
: nr(r)
28+
, nc(c)
29+
, nnz(nnz)
30+
, rws(rs)
31+
, cls(cs)
32+
{
33+
}
2634
LAS_INLINE Sparsity * csrFromArray(int rws, int cls, int nnz, int * row_arr, int * col_arr)
2735
{
2836
return reinterpret_cast<Sparsity*>(new CSR(rws,cls,nnz,row_arr,col_arr));

src/lasPETSc.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ namespace las
1717
Solve * createPetscQNSolve(void * a);
1818
MatVecMult * createPetscMatVecMult();
1919
MatMatMult * createPetscMatMatMult();
20+
ScalarMatMult * createPetscScalarMatMult();
2021
template <>
2122
void finalizeMatrix<petsc>(Mat * mat);
2223
template <>

src/lasPETSc_impl.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,28 @@ namespace las
482482
{
483483
return new PetscMatMatMult;
484484
}
485+
class PetscScalarMatMult : public ScalarMatMult
486+
{
487+
public:
488+
virtual void exec(scalar s, Mat * a, Mat ** c)
489+
{
490+
if (c == nullptr)
491+
{
492+
PetscErrorCode ierr = ::MatScale(*getPetscMat(a), s);
493+
}
494+
else
495+
{
496+
std::cerr << "Out of place matrix scalar multiplication not "
497+
"implemented in petsc"
498+
<< std::endl;
499+
std::abort();
500+
}
501+
}
502+
};
503+
LAS_INLINE ScalarMatMult * createPetscScalarMatMult()
504+
{
505+
return new PetscScalarMatMult;
506+
}
485507
template <>
486508
LAS_INLINE void finalizeMatrix<petsc>(Mat * mat)
487509
{

0 commit comments

Comments
 (0)