sgemm_tcore: Support fp16 input generation in host code
This commit is contained in:
4018
tests/regression/sgemm_tcore/half.hpp
Normal file
4018
tests/regression/sgemm_tcore/half.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,12 @@
|
||||
#include <string.h>
|
||||
#include <vortex.h>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
#include "common.h"
|
||||
#include "half.hpp"
|
||||
|
||||
using half_float::half;
|
||||
using half_float::half_cast;
|
||||
|
||||
#define RT_CHECK(_expr) \
|
||||
do { \
|
||||
@@ -21,8 +26,8 @@
|
||||
const char* kernel_file = "kernel.bin";
|
||||
uint32_t count = 0;
|
||||
|
||||
std::vector<float> src_a_data;
|
||||
std::vector<float> src_b_data;
|
||||
template <typename T> std::vector<T> src_a_data;
|
||||
template <typename T> std::vector<T> src_b_data;
|
||||
std::vector<float> ref_data;
|
||||
|
||||
vx_device_h device = nullptr;
|
||||
@@ -65,28 +70,45 @@ void cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void generate_source_matrix(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) {
|
||||
src_a_data.resize(dim_m * dim_k);
|
||||
src_b_data.resize(dim_k * dim_n);
|
||||
static_assert(std::is_same_v<half, T> || std::is_same_v<float, T>,
|
||||
"unsupported floating point datatype");
|
||||
|
||||
for (uint32_t i = 0; i < src_a_data.size(); ++i) {
|
||||
src_a_data[i] = static_cast<float>(i);
|
||||
std::cout << "A: " << i << ": value=" << src_a_data[i] << std::endl;
|
||||
src_a_data<T>.resize(dim_m * dim_k);
|
||||
src_b_data<T>.resize(dim_k * dim_n);
|
||||
|
||||
for (uint32_t i = 0; i < src_a_data<T>.size(); ++i) {
|
||||
if constexpr (std::is_same_v<half, T>) {
|
||||
src_a_data<T>[i] = half_cast<half>(static_cast<float>(i));
|
||||
} else if (std::is_same_v<float, T>) {
|
||||
src_a_data<T>[i] = static_cast<float>(i);
|
||||
}
|
||||
std::cout << "A: " << i << ": value=" << src_a_data<T>[i] << std::endl;
|
||||
}
|
||||
for (uint32_t i = 0; i < src_b_data.size(); ++i) {
|
||||
src_b_data[i] = static_cast<float>(i);
|
||||
std::cout << "B: " << i << ": value=" << src_b_data[i] << std::endl;
|
||||
for (uint32_t i = 0; i < src_b_data<T>.size(); ++i) {
|
||||
if constexpr (std::is_same_v<half, T>) {
|
||||
src_b_data<T>[i] = half_cast<half>(static_cast<float>(i));
|
||||
} else if (std::is_same_v<float, T>) {
|
||||
src_b_data<T>[i] = static_cast<float>(i);
|
||||
}
|
||||
std::cout << "B: " << i << ": value=" << src_b_data<T>[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void generate_reference_matmul(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) {
|
||||
static_assert(std::is_same_v<half, T> || std::is_same_v<float, T>,
|
||||
"unsupported floating point datatype");
|
||||
|
||||
ref_data.resize(dim_m * dim_n);
|
||||
|
||||
for (uint32_t i = 0; i < dim_m; ++i) {
|
||||
for (uint32_t j = 0; j < dim_n; ++j) {
|
||||
float ref = 0.0f;
|
||||
for (uint32_t k = 0; k < dim_k; ++k) {
|
||||
ref += src_a_data[dim_k * i + k] * src_b_data[dim_n * k + j];
|
||||
ref += static_cast<float>(src_a_data<T>[dim_k * i + k]) *
|
||||
static_cast<float>(src_b_data<T>[dim_n * k + j]);
|
||||
}
|
||||
ref_data.at(dim_n * i + j) = ref;
|
||||
}
|
||||
@@ -151,8 +173,9 @@ int main(int argc, char *argv[]) {
|
||||
uint32_t dim_n = 64;
|
||||
uint32_t dim_k = 64;
|
||||
|
||||
generate_source_matrix(dim_m, dim_n, dim_k);
|
||||
generate_reference_matmul(dim_m, dim_n, dim_k);
|
||||
using float_type = half;
|
||||
generate_source_matrix<float_type>(dim_m, dim_n, dim_k);
|
||||
generate_reference_matmul<float_type>(dim_m, dim_n, dim_k);
|
||||
|
||||
std::cout << "write reference output" << std::endl;
|
||||
std::ofstream ref_file("reference.c.bin", std::ios::binary | std::ios::out);
|
||||
@@ -163,9 +186,9 @@ int main(int argc, char *argv[]) {
|
||||
ref_file.write(reinterpret_cast<char *>(ref_data.data()), ref_data.size() * sizeof(ref_data[0]));
|
||||
ref_file.close();
|
||||
|
||||
uint32_t src_a_buf_size = src_a_data.size() * sizeof(src_a_data[0]);
|
||||
uint32_t src_b_buf_size = src_b_data.size() * sizeof(src_b_data[0]);
|
||||
uint32_t dst_buf_size = ref_data.size() * sizeof(src_a_data[0]);
|
||||
uint32_t src_a_buf_size = src_a_data<float_type>.size() * sizeof(src_a_data<float_type>[0]);
|
||||
uint32_t src_b_buf_size = src_b_data<float_type>.size() * sizeof(src_b_data<float_type>[0]);
|
||||
uint32_t dst_buf_size = ref_data.size() * sizeof(src_a_data<float_type>[0]);
|
||||
|
||||
std::cout << "buffer size: " << dst_buf_size << " bytes" << std::endl;
|
||||
|
||||
@@ -225,7 +248,8 @@ int main(int argc, char *argv[]) {
|
||||
{
|
||||
{
|
||||
auto buf_ptr = staging_buf.data();
|
||||
memcpy(buf_ptr, src_a_data.data(), src_a_data.size() * sizeof(float));
|
||||
memcpy(buf_ptr, src_a_data<float_type>.data(),
|
||||
src_a_data<float_type>.size() * sizeof(float_type));
|
||||
RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_a, staging_buf.data(),
|
||||
src_a_buf_size));
|
||||
|
||||
@@ -242,7 +266,8 @@ int main(int argc, char *argv[]) {
|
||||
}
|
||||
{
|
||||
auto buf_ptr = staging_buf.data();
|
||||
memcpy(buf_ptr, src_b_data.data(), src_b_data.size() * sizeof(float));
|
||||
memcpy(buf_ptr, src_b_data<float_type>.data(),
|
||||
src_b_data<float_type>.size() * sizeof(float_type));
|
||||
RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_b, staging_buf.data(),
|
||||
src_b_buf_size));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user