|
#include<stdio.h>#include<time.h>#include<stdlib.h>#include "mpi.h"#define N 3int main(int argc,char **argv){MPI_Init(&argc,&argv);int rank,size;MPI_Comm_rank(MPI_COMM_WORLD,&rank);MPI_Comm_size(MPI_COMM_WORLD,&size);int m=atoi(argv[1]);int n=atoi(argv[2]);int block=(m+size-1)/size*n;if(rank==size-1 && block>m*n-(size-1)*block) block=m*n-(size-1)*block;int *A=NULL;int *TA=(int*)malloc(block*sizeof(int));int **B=(int**)malloc(m*sizeof(int*));int *counts=NULL;int *ofss=NULL;int i,j;for(i=0;i<m;++i)B[i]=(int*)malloc(n*sizeof(int));if(rank==0){counts=(int*)malloc(size*sizeof(int));ofss=(int*)malloc(size*sizeof(int));for(i=0;i<size-1;++i){counts[i]=block;ofss[i]=i*block;}counts[i]=block<=m*n-i*block?block:m*n-i*block;ofss[i]=i*block;//for(i=0;i<size;++i)//printf("%d\t%d\n",counts[i],ofss[i]);//printf("constructing %d * %d matrix\n",m,n);A=(int*)malloc(m*n*sizeof(int));srand((int)time(0));for(i=0;i<m;++i){for(j=0;j<n;++j){A[i*n+j]=(int)rand()%N;B[i][j]=(int)rand()%N;}}/*printf("A:\n");for(i=0;i<m;++i){for(j=0;j<n;++j)printf("%d\t",A[i*n+j]);printf("\n");}printf("B:\n");for(i=0;i<n;++i){for(j=0;j<m;++j)printf("%d\t",B[j][i]);printf("\n");}*/}for(i=0;i<m;++i)MPI_Bcast(B[i],n,MPI_INT,0,MPI_COMM_WORLD);MPI_Scatterv(A,counts,ofss,MPI_INT,TA,block,MPI_INT,0,MPI_COMM_WORLD);/*printf("B in %d:\n",rank); for(i=0;i<m;++i) { for(j=0;j<n;++j) printf("%d\t",B[i][j]); printf("\n"); } printf("A in %d\n",rank); for(i=0;i<block;++i) printf("%d\t",TA[i]); printf("\n");*/int col=block/n;int *RETA=(int*)malloc(col*m*sizeof(int));int *RES=NULL;if(rank==0)RES=(int*)malloc(m*m*sizeof(int));int k;for(k=0;k<m;++k){for(i=0;i<col;++i){int tmp=0;for(j=0;j<n;++j)tmp+=TA[i*n+j]*B[k][j];RETA[i*m+k]=tmp;}}if(rank==0){ofss[0]=0;counts[0]=col*m;for(i=1;i<size;++i){counts[i]=counts[i]/n*m;ofss[i]=ofss[i-1]+counts[i-1];}}MPI_Gatherv(RETA,col*m,MPI_INT,RES,counts,ofss,MPI_INT,0,MPI_COMM_WORLD);if(rank==0){printf("A:\n");for(i=0;i<m;++i){printf("[\t");for(j=0;j<n;++j)printf("%d\t",A[i*n+j]);printf("]\n");}printf("B:\n");for(i=0;i<n;++i){printf("[\t");for(j=0;j<m;++j)printf("%d\t",B[j][i]);printf("]\n");}printf("=:\n");for(i=0;i<m;++i){printf("[\t");for(j=0;j<m;++j)printf("%d\t",RES[i*m+j]);printf("]\n");}free(A);free(counts);free(ofss);free(RES);}free(TA);free(RETA);for(i=0;i<m;++i)free(B[i]);free(B);MPI_Finalize();return 0;} |
|