先实现一个稀疏矩阵类,这里选择通过实现随机值矩阵,但是将其中大部分值设置为0,在将原先矩阵输入之后转为稀疏矩阵保存。
使用CSR格式的稀疏矩阵。
CSR将非零值保存在一个一维数组中,value或者data, 然后将非零值对应的列索引保存在另一个一维数组中colIdx。
另外一个行指针 row_ptr, 其中row_ptr[i]表示第i行的非零元素在data中的起始位置, 当然如果第i行全为0,则row_ptr[i] 等于 row_ptr[i+1] 。
为了方便,使用vector先保存非零元素data 以及 对应索引。因为事先是不知道矩阵中非零元素个数的。
CSR稀疏矩阵类实现:
class SparseMatrix
{
public:SparseMatrix(){};SparseMatrix(u_int num_rows,u_int num_cols,float *data, u_int *col_index, u_int *row_ptr):num_rows(num_rows),num_cols(num_cols),data(data),col_index(col_index),row_ptr(row_ptr){}SparseMatrix(u_int num_rows,float *data, u_int *col_index, u_int *row_ptr):num_rows(num_rows),data(data),col_index(col_index),row_ptr(row_ptr){}SparseMatrix(float* arr, u_int rows, u_int cols){/*通过原始矩阵的一维数组,转化为CSR格式的稀疏矩阵params:arr 指向原始矩阵的一维数组rows 原始矩阵的行维度cols 原始矩阵的列维度*/num_rows = rows;num_cols = cols;row_ptr = new u_int[num_rows + 1];std::vector<u_int> vec_col_index;std::vector<u_int> vec_data;vec_data.reserve(num_cols);vec_col_index.reserve(num_cols);row_ptr[0] = 0;for(u_int i = 0; i<num_rows; i++){row_ptr[i+1] = row_ptr[i];for(u_int j = 0; j< num_cols; j++){u_int index = i*cols + j;if(abs(arr[index]) > EPSILON){vec_data.emplace_back(arr[index]);vec_col_index.emplace_back(j);row_ptr[i+1]++;}}}this->data_length= vec_data.size();data = new float[data_length];col_index = new u_int[data_length];for(int i=0;i<data_length;i++){data[i] = vec_data[i];col_index[i] = vec_col_index[i];}}~SparseMatrix(){if(data != nullptr){delete[] data;}if(col_index != nullptr){delete[] col_index;}if(row_ptr != nullptr){delete[] row_ptr;}}void printSparseMatrix(){/*通过CSR打印出原始矩阵样貌*/for(u_int row=0; row<num_rows; row++){u_int row_start = row_ptr[row];u_int row_end = row_ptr[row+1];u_int p_row = row_start;for(int i=0;i<num_cols; i++){if(i==col_index[p_row]&&(p_row<row_end)){printf("%2.2f ", data[p_row]);p_row++;}else{printf("%2.2f ", 0.0);}}printf("\n");}}u_int num_rows;u_int num_cols;float *data;u_int *col_index;u_int *row_ptr;u_int data_length;
};int main(){float arr[] = {1,7,0,0,5,0,3,9,0,2,8,0,0,0,0,6};SparseMatrix sm = SparseMatrix(arr, 4,4);sm.printSparseMatrix();return 0;
}
最后结果:
OK!!!