yesidoaking 发表于 2013-1-26 12:37:16

MPI——矩阵乘法

#include<stdio.h>#include<time.h>#include<stdlib.h>#include "mpi.h"#define N 3int main(int argc,char **argv){MPI_Init(&argc,&argv);int rank,size;MPI_Comm_rank(MPI_COMM_WORLD,&rank);MPI_Comm_size(MPI_COMM_WORLD,&size);int m=atoi(argv);int n=atoi(argv);int block=(m+size-1)/size*n;if(rank==size-1 && block>m*n-(size-1)*block) block=m*n-(size-1)*block;int *A=NULL;int *TA=(int*)malloc(block*sizeof(int));int **B=(int**)malloc(m*sizeof(int*));int *counts=NULL;int *ofss=NULL;int i,j;for(i=0;i<m;++i)B=(int*)malloc(n*sizeof(int));if(rank==0){counts=(int*)malloc(size*sizeof(int));ofss=(int*)malloc(size*sizeof(int));for(i=0;i<size-1;++i){counts=block;ofss=i*block;}counts=block<=m*n-i*block?block:m*n-i*block;ofss=i*block;//for(i=0;i<size;++i)//printf("%d\t%d\n",counts,ofss);//printf("constructing %d * %d matrix\n",m,n);A=(int*)malloc(m*n*sizeof(int));srand((int)time(0));for(i=0;i<m;++i){for(j=0;j<n;++j){A=(int)rand()%N;B=(int)rand()%N;}}/*printf("A:\n");for(i=0;i<m;++i){for(j=0;j<n;++j)printf("%d\t",A);printf("\n");}printf("B:\n");for(i=0;i<n;++i){for(j=0;j<m;++j)printf("%d\t",B);printf("\n");}*/}for(i=0;i<m;++i)MPI_Bcast(B,n,MPI_INT,0,MPI_COMM_WORLD);MPI_Scatterv(A,counts,ofss,MPI_INT,TA,block,MPI_INT,0,MPI_COMM_WORLD);/*printf("B in %d:\n",rank);for(i=0;i<m;++i){for(j=0;j<n;++j)printf("%d\t",B);printf("\n");}printf("A in %d\n",rank);for(i=0;i<block;++i)printf("%d\t",TA);printf("\n");*/int col=block/n;int *RETA=(int*)malloc(col*m*sizeof(int));int *RES=NULL;if(rank==0)RES=(int*)malloc(m*m*sizeof(int));int k;for(k=0;k<m;++k){for(i=0;i<col;++i){int tmp=0;for(j=0;j<n;++j)tmp+=TA*B;RETA=tmp;}}if(rank==0){ofss=0;counts=col*m;for(i=1;i<size;++i){counts=counts/n*m;ofss=ofss+counts;}}MPI_Gatherv(RETA,col*m,MPI_INT,RES,counts,ofss,MPI_INT,0,MPI_COMM_WORLD);if(rank==0){printf("A:\n");for(i=0;i<m;++i){printf("[\t");for(j=0;j<n;++j)printf("%d\t",A);printf("]\n");}printf("B:\n");for(i=0;i<n;++i){printf("[\t");for(j=0;j<m;++j)printf("%d\t",B);printf("]\n");}printf("=:\n");for(i=0;i<m;++i){printf("[\t");for(j=0;j<m;++j)printf("%d\t",RES);printf("]\n");}free(A);free(counts);free(ofss);free(RES);}free(TA);free(RETA);for(i=0;i<m;++i)free(B);free(B);MPI_Finalize();return 0;}
页: [1]
查看完整版本: MPI——矩阵乘法