Skip to content

Commit 8a51363

Browse files
author
Bruce J Palmer
committed
Modified scatter test to test all index and data types.
1 parent 26f603d commit 8a51363

File tree

1 file changed

+91
-21
lines changed

1 file changed

+91
-21
lines changed

xga/testing/scatter_unit_test.cpp

Lines changed: 91 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,48 +4,45 @@
44
#include <iostream>
55

66
#define DIM 256
7-
int main(int argc, char **argv)
7+
template<typename idx_type, typename data_type>
8+
void scatter_test()
89
{
9-
XGA::Environment *env = XGA::Environment::instance(&argc,&argv);
10+
XGA::Environment *env = XGA::Environment::instance();
1011
XGA::Group *group = env->getWorldGroup();
1112
int rank = group->rank();
1213
int size = group->size();
1314
int wrank;
1415
MPI_Comm_rank(MPI_COMM_WORLD,&wrank);
1516
/* Create global array */
1617
int ndim = 3;
17-
int64_t dims[3];
18+
idx_type dims[3];
1819
dims[0] = DIM;
1920
dims[1] = DIM;
2021
dims[2] = DIM;
21-
if (rank == 0) {
22-
printf("\nTesting SCATTER on a %d x %d x %d array",dims[0],dims[1],dims[2]);
23-
printf(" running on %d processors\n",size);
24-
}
25-
XGA::GlobalArray<double> ga(group, ndim, dims);
22+
XGA::GlobalArray<data_type> ga(group, ndim, dims);
2623
ga.allocate();
2724

2825
/* initialize global array using scatter */
29-
int64_t nelems = static_cast<int64_t>(static_cast<double>(dims[0]*dims[1]*dims[2])
26+
idx_type nelems = static_cast<idx_type>(static_cast<double>(dims[0]*dims[1]*dims[2])
3027
/ static_cast<double>(size))+1;
31-
int64_t total = dims[0]*dims[1]*dims[2];
32-
double *values = new double[nelems];
33-
int64_t *subscripts = new int64_t[nelems*ndim];
34-
int64_t i, j, k, n, idx, icnt;
28+
idx_type total = dims[0]*dims[1]*dims[2];
29+
data_type *values = new data_type[nelems];
30+
idx_type *subscripts = new idx_type[nelems*ndim];
31+
idx_type i, j, k, n, idx, icnt;
3532
icnt = 0;
3633
for (n=rank; n<total; n+=size) {
3734
idx = n;
3835
k = idx%dims[2];
3936
idx = (idx-k)/dims[2];
4037
j = idx%dims[1];
4138
i = (idx-j)/dims[1];
42-
values[icnt] = static_cast<double>(n);
39+
values[icnt] = static_cast<data_type>(n);
4340
subscripts[ndim*icnt] = i;
4441
subscripts[ndim*icnt+1] = j;
4542
subscripts[ndim*icnt+2] = k;
4643
icnt++;
4744
}
48-
int64_t lo[3], hi[3], ld[2];
45+
idx_type lo[3], hi[3], ld[2];
4946
ga.distribution(rank,lo,hi);
5047
ga.scatter(values, subscripts, icnt);
5148
ga.sync();
@@ -54,24 +51,24 @@ int main(int argc, char **argv)
5451
/* Check values */
5552
nelems = (hi[0]-lo[0]+1)*(hi[1]-lo[1]+1)*(hi[2]-lo[2]+1);
5653
/* Access local data in array */
57-
int64_t idim = hi[0]-lo[0]+1;
58-
int64_t jdim = hi[1]-lo[1]+1;
59-
int64_t kdim = hi[2]-lo[2]+1;
54+
idx_type idim = hi[0]-lo[0]+1;
55+
idx_type jdim = hi[1]-lo[1]+1;
56+
idx_type kdim = hi[2]-lo[2]+1;
6057
void *vptr;
6158
ga.accessPtr(lo, hi, &vptr, ld);
62-
double *dptr = static_cast<double*>(vptr);
59+
data_type *dptr = static_cast<data_type*>(vptr);
6360
int ok = 1;
6461
int chk;
6562
for (i=0; i<idim; i++) {
6663
for (j=0; j<jdim; j++) {
6764
for (k=0; k<kdim; k++) {
6865
if (dptr[k+kdim*j+kdim*jdim*i]
69-
!= static_cast<double>(k+lo[2] + (j+lo[1])*dims[2]
66+
!= static_cast<data_type>(k+lo[2] + (j+lo[1])*dims[2]
7067
+ (i+lo[0])*dims[2]*dims[1])) {
7168
printf("p[%d] Check fails for i: %ld j: %ld k: %ld"
7269
" actual: %f expected: %f\n",
7370
wrank,i+lo[0],j+lo[1],k+lo[2],dptr[k+kdim*j+i*kdim*jdim],
74-
static_cast<double>(k+lo[2] + (j+lo[1])*dims[2]
71+
static_cast<data_type>(k+lo[2] + (j+lo[1])*dims[2]
7572
+ (i+lo[0])*dims[2]*dims[1]));
7673
ok = 0;
7774
}
@@ -87,6 +84,79 @@ int main(int argc, char **argv)
8784
printf("\n scatter test FAILS\n");
8885
}
8986
ga.clear();
87+
}
88+
int main(int argc, char **argv)
89+
{
90+
XGA::Environment *env = XGA::Environment::instance(&argc,&argv);
91+
XGA::Group *group = env->getWorldGroup();
92+
int rank = group->rank();
93+
int size = group->size();
94+
/* Create global array */
95+
if (rank == 0) {
96+
int ndim = 3;
97+
int64_t dims[3];
98+
dims[0] = DIM;
99+
dims[1] = DIM;
100+
dims[2] = DIM;
101+
printf("\nTesting SCATTER on a %d x %d x %d array",dims[0],dims[1],dims[2]);
102+
printf(" running on %d processors\n",size);
103+
}
104+
if (rank == 0) {
105+
printf("\nTesting SCATTER for ints and int64_t indices\n");
106+
}
107+
scatter_test<int64_t,int>();
108+
if (rank == 0) {
109+
printf("\nTesting SCATTER for longs and int64_t indices\n");
110+
}
111+
scatter_test<int64_t,long>();
112+
if (rank == 0) {
113+
printf("\nTesting SCATTER for long longs and int64_t indices\n");
114+
}
115+
scatter_test<int64_t,long long>();
116+
if (rank == 0) {
117+
printf("\nTesting SCATTER for floats and int64_t indices\n");
118+
}
119+
scatter_test<int64_t,float>();
120+
if (rank == 0) {
121+
printf("\nTesting SCATTER for doubles and int64_t indices\n");
122+
}
123+
scatter_test<int64_t,double>();
124+
if (rank == 0) {
125+
printf("\nTesting SCATTER for complex floats and int64_t indices\n");
126+
}
127+
scatter_test<int64_t,std::complex<float> >();
128+
if (rank == 0) {
129+
printf("\nTesting SCATTER for complex doubles and int64_t indices\n");
130+
}
131+
scatter_test<int64_t,std::complex<double> >();
132+
if (rank == 0) {
133+
printf("\nTesting SCATTER for ints and int indices\n");
134+
}
135+
scatter_test<int,int>();
136+
if (rank == 0) {
137+
printf("\nTesting SCATTER for longs and int indices\n");
138+
}
139+
scatter_test<int,long>();
140+
if (rank == 0) {
141+
printf("\nTesting SCATTER for long longs and int indices\n");
142+
}
143+
scatter_test<int,long long>();
144+
if (rank == 0) {
145+
printf("\nTesting SCATTER for floats and int indices\n");
146+
}
147+
scatter_test<int,float>();
148+
if (rank == 0) {
149+
printf("\nTesting SCATTER for doubles and int indices\n");
150+
}
151+
scatter_test<int,double>();
152+
if (rank == 0) {
153+
printf("\nTesting SCATTER for complex floats and int indices\n");
154+
}
155+
scatter_test<int,std::complex<float> >();
156+
if (rank == 0) {
157+
printf("\nTesting SCATTER for complex doubles and int indices\n");
158+
}
159+
scatter_test<int,std::complex<double> >();
90160
env->finalize();
91161
MPI_Finalize();
92162
return 0;

0 commit comments

Comments
 (0)