diff --git a/hw/dpi/float_dpi.cpp b/hw/dpi/float_dpi.cpp index 6a810555..570d6bf2 100644 --- a/hw/dpi/float_dpi.cpp +++ b/hw/dpi/float_dpi.cpp @@ -358,6 +358,15 @@ float c_D_tile[M][M]; // code assumes that svBitVecVal is basically a uint32_t static_assert(sizeof(svBitVecVal) == 4); +void clear_float_array(float* c_tile, int rows, int cols) { + for (int i = 0; i < rows; i += 1) { + for (int j = 0; j < cols; j += 1) { + int index = i * cols + j; + c_tile[index] = 0.0f; + } + } +} + void fill_float_array(const svBitVecVal* sv_tile, float* c_tile, int rows, int cols) { for (int i = 0; i < rows; i += 1) { @@ -396,6 +405,11 @@ void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, if (!enable) { return; } + clear_float_array(&c_A_tile[0][0], M, K); + clear_float_array(&c_B_tile[0][0], K, M); + clear_float_array(&c_C_tile[0][0], M, M); + clear_float_array(&c_D_tile[0][0], M, M); + // std::cout << "A: " << std::endl; fill_float_array(A_tile, &c_A_tile[0][0], M, K); // std::cout << "B: " << std::endl;