44#include < iostream>
55
66#define DIM 256
7- int main (int argc, char **argv)
7+ template <typename idx_type, typename data_type>
8+ void scatter_test ()
89{
9- XGA::Environment *env = XGA::Environment::instance (&argc,&argv );
10+ XGA::Environment *env = XGA::Environment::instance ();
1011 XGA::Group *group = env->getWorldGroup ();
1112 int rank = group->rank ();
1213 int size = group->size ();
1314 int wrank;
1415 MPI_Comm_rank (MPI_COMM_WORLD,&wrank);
1516 /* Create global array */
1617 int ndim = 3 ;
17- int64_t dims[3 ];
18+ idx_type dims[3 ];
1819 dims[0 ] = DIM;
1920 dims[1 ] = DIM;
2021 dims[2 ] = DIM;
21- if (rank == 0 ) {
22- printf (" \n Testing SCATTER on a %d x %d x %d array" ,dims[0 ],dims[1 ],dims[2 ]);
23- printf (" running on %d processors\n " ,size);
24- }
25- XGA::GlobalArray<double > ga (group, ndim, dims);
22+ XGA::GlobalArray<data_type> ga (group, ndim, dims);
2623 ga.allocate ();
2724
2825 /* initialize global array using scatter */
29- int64_t nelems = static_cast <int64_t >(static_cast <double >(dims[0 ]*dims[1 ]*dims[2 ])
26+ idx_type nelems = static_cast <idx_type >(static_cast <double >(dims[0 ]*dims[1 ]*dims[2 ])
3027 / static_cast <double >(size))+1 ;
31- int64_t total = dims[0 ]*dims[1 ]*dims[2 ];
32- double *values = new double [nelems];
33- int64_t *subscripts = new int64_t [nelems*ndim];
34- int64_t i, j, k, n, idx, icnt;
28+ idx_type total = dims[0 ]*dims[1 ]*dims[2 ];
29+ data_type *values = new data_type [nelems];
30+ idx_type *subscripts = new idx_type [nelems*ndim];
31+ idx_type i, j, k, n, idx, icnt;
3532 icnt = 0 ;
3633 for (n=rank; n<total; n+=size) {
3734 idx = n;
3835 k = idx%dims[2 ];
3936 idx = (idx-k)/dims[2 ];
4037 j = idx%dims[1 ];
4138 i = (idx-j)/dims[1 ];
42- values[icnt] = static_cast <double >(n);
39+ values[icnt] = static_cast <data_type >(n);
4340 subscripts[ndim*icnt] = i;
4441 subscripts[ndim*icnt+1 ] = j;
4542 subscripts[ndim*icnt+2 ] = k;
4643 icnt++;
4744 }
48- int64_t lo[3 ], hi[3 ], ld[2 ];
45+ idx_type lo[3 ], hi[3 ], ld[2 ];
4946 ga.distribution (rank,lo,hi);
5047 ga.scatter (values, subscripts, icnt);
5148 ga.sync ();
@@ -54,24 +51,24 @@ int main(int argc, char **argv)
5451 /* Check values */
5552 nelems = (hi[0 ]-lo[0 ]+1 )*(hi[1 ]-lo[1 ]+1 )*(hi[2 ]-lo[2 ]+1 );
5653 /* Access local data in array */
57- int64_t idim = hi[0 ]-lo[0 ]+1 ;
58- int64_t jdim = hi[1 ]-lo[1 ]+1 ;
59- int64_t kdim = hi[2 ]-lo[2 ]+1 ;
54+ idx_type idim = hi[0 ]-lo[0 ]+1 ;
55+ idx_type jdim = hi[1 ]-lo[1 ]+1 ;
56+ idx_type kdim = hi[2 ]-lo[2 ]+1 ;
6057 void *vptr;
6158 ga.accessPtr (lo, hi, &vptr, ld);
62- double *dptr = static_cast <double *>(vptr);
59+ data_type *dptr = static_cast <data_type *>(vptr);
6360 int ok = 1 ;
6461 int chk;
6562 for (i=0 ; i<idim; i++) {
6663 for (j=0 ; j<jdim; j++) {
6764 for (k=0 ; k<kdim; k++) {
6865 if (dptr[k+kdim*j+kdim*jdim*i]
69- != static_cast <double >(k+lo[2 ] + (j+lo[1 ])*dims[2 ]
66+ != static_cast <data_type >(k+lo[2 ] + (j+lo[1 ])*dims[2 ]
7067 + (i+lo[0 ])*dims[2 ]*dims[1 ])) {
7168 printf (" p[%d] Check fails for i: %ld j: %ld k: %ld"
7269 " actual: %f expected: %f\n " ,
7370 wrank,i+lo[0 ],j+lo[1 ],k+lo[2 ],dptr[k+kdim*j+i*kdim*jdim],
74- static_cast <double >(k+lo[2 ] + (j+lo[1 ])*dims[2 ]
71+ static_cast <data_type >(k+lo[2 ] + (j+lo[1 ])*dims[2 ]
7572 + (i+lo[0 ])*dims[2 ]*dims[1 ]));
7673 ok = 0 ;
7774 }
@@ -87,6 +84,79 @@ int main(int argc, char **argv)
8784 printf (" \n scatter test FAILS\n " );
8885 }
8986 ga.clear ();
87+ }
88+ int main (int argc, char **argv)
89+ {
90+ XGA::Environment *env = XGA::Environment::instance (&argc,&argv);
91+ XGA::Group *group = env->getWorldGroup ();
92+ int rank = group->rank ();
93+ int size = group->size ();
94+ /* Create global array */
95+ if (rank == 0 ) {
96+ int ndim = 3 ;
97+ int64_t dims[3 ];
98+ dims[0 ] = DIM;
99+ dims[1 ] = DIM;
100+ dims[2 ] = DIM;
101+ printf (" \n Testing SCATTER on a %d x %d x %d array" ,dims[0 ],dims[1 ],dims[2 ]);
102+ printf (" running on %d processors\n " ,size);
103+ }
104+ if (rank == 0 ) {
105+ printf (" \n Testing SCATTER for ints and int64_t indices\n " );
106+ }
107+ scatter_test<int64_t ,int >();
108+ if (rank == 0 ) {
109+ printf (" \n Testing SCATTER for longs and int64_t indices\n " );
110+ }
111+ scatter_test<int64_t ,long >();
112+ if (rank == 0 ) {
113+ printf (" \n Testing SCATTER for long longs and int64_t indices\n " );
114+ }
115+ scatter_test<int64_t ,long long >();
116+ if (rank == 0 ) {
117+ printf (" \n Testing SCATTER for floats and int64_t indices\n " );
118+ }
119+ scatter_test<int64_t ,float >();
120+ if (rank == 0 ) {
121+ printf (" \n Testing SCATTER for doubles and int64_t indices\n " );
122+ }
123+ scatter_test<int64_t ,double >();
124+ if (rank == 0 ) {
125+ printf (" \n Testing SCATTER for complex floats and int64_t indices\n " );
126+ }
127+ scatter_test<int64_t ,std::complex <float > >();
128+ if (rank == 0 ) {
129+ printf (" \n Testing SCATTER for complex doubles and int64_t indices\n " );
130+ }
131+ scatter_test<int64_t ,std::complex <double > >();
132+ if (rank == 0 ) {
133+ printf (" \n Testing SCATTER for ints and int indices\n " );
134+ }
135+ scatter_test<int ,int >();
136+ if (rank == 0 ) {
137+ printf (" \n Testing SCATTER for longs and int indices\n " );
138+ }
139+ scatter_test<int ,long >();
140+ if (rank == 0 ) {
141+ printf (" \n Testing SCATTER for long longs and int indices\n " );
142+ }
143+ scatter_test<int ,long long >();
144+ if (rank == 0 ) {
145+ printf (" \n Testing SCATTER for floats and int indices\n " );
146+ }
147+ scatter_test<int ,float >();
148+ if (rank == 0 ) {
149+ printf (" \n Testing SCATTER for doubles and int indices\n " );
150+ }
151+ scatter_test<int ,double >();
152+ if (rank == 0 ) {
153+ printf (" \n Testing SCATTER for complex floats and int indices\n " );
154+ }
155+ scatter_test<int ,std::complex <float > >();
156+ if (rank == 0 ) {
157+ printf (" \n Testing SCATTER for complex doubles and int indices\n " );
158+ }
159+ scatter_test<int ,std::complex <double > >();
90160 env->finalize ();
91161 MPI_Finalize ();
92162 return 0 ;
0 commit comments