29#ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BFLOAT16_H_
30#define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BFLOAT16_H_
33#if defined(__HIPCC_RTC__)
34 #define __HOST_DEVICE__ __device__
36 #define __HOST_DEVICE__ __host__ __device__
39#if __cplusplus < 201103L || !defined(__HIPCC__)
53#include <hip/hip_runtime.h>
55#pragma clang diagnostic push
56#pragma clang diagnostic ignored "-Wshadow"
70 : data(float_to_bfloat16(f))
74 explicit __HOST_DEVICE__
hip_bfloat16(
float f, truncate_t)
75 : data(truncate_float_to_bfloat16(f))
80 __HOST_DEVICE__
operator float()
const
86 } u = {uint32_t(data) << 16};
92 data = float_to_bfloat16(f);
96 static __HOST_DEVICE__
hip_bfloat16 round_to_bfloat16(
float f)
99 output.data = float_to_bfloat16(f);
103 static __HOST_DEVICE__
hip_bfloat16 round_to_bfloat16(
float f, truncate_t)
106 output.data = truncate_float_to_bfloat16(f);
111 static __HOST_DEVICE__ __hip_uint16_t float_to_bfloat16(
float f)
118 if(~u.int32 & 0x7f800000)
136 u.int32 += 0x7fff + ((u.int32 >> 16) & 1);
138 else if(u.int32 & 0xffff)
150 return __hip_uint16_t(u.int32 >> 16);
154 static __HOST_DEVICE__ __hip_uint16_t truncate_float_to_bfloat16(
float f)
161 return __hip_uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
164#pragma clang diagnostic pop
169} hip_bfloat16_public;
// Compile-time guard: stop the build if hip_bfloat16 is not a standard-layout
// type. Per the assertion message, standard layout is what keeps the type
// compatible with C.
// NOTE(review): __hip_internal::is_standard_layout is presumably the HIP-internal
// counterpart of std::is_standard_layout - confirm against its definition.
171static_assert(__hip_internal::is_standard_layout<hip_bfloat16>{},
172 "hip_bfloat16 is not a standard layout type, and thus is "
173 "incompatible with C.");
// Compile-time guard: stop the build if hip_bfloat16 is not a trivial type.
// Per the assertion message, triviality is required for the type to remain
// compatible with C.
// NOTE(review): __hip_internal::is_trivial is presumably the HIP-internal
// counterpart of std::is_trivial - confirm against its definition.
175static_assert(__hip_internal::is_trivial<hip_bfloat16>{},
176 "hip_bfloat16 is not a trivial type, and thus is "
177 "incompatible with C.");
178#if !defined(__HIPCC_RTC__)
// Compile-time guard: the internal hip_bfloat16 and the hip_bfloat16_public
// struct must have the same size, and their `data` members must sit at the
// same offset; otherwise the build fails with the message below.
// This check directly follows the `#if !defined(__HIPCC_RTC__)` directive, so
// it is compiled only when __HIPCC_RTC__ is not defined.
179static_assert(
sizeof(
hip_bfloat16) ==
sizeof(hip_bfloat16_public)
180 && offsetof(
hip_bfloat16, data) == offsetof(hip_bfloat16_public, data),
181 "internal hip_bfloat16 does not match public hip_bfloat16");
183inline std::ostream& operator<<(std::ostream& os,
const hip_bfloat16& bf16)
185 return os << float(bf16);
216 return float(a) < float(b);
220 return float(a) == float(b);
279 return !(~a.data & 0x7f80) && !(a.data & 0x7f);
283 return !(~a.data & 0x7f80) && +(a.data & 0x7f);
287 return !(a.data & 0x7fff);
Struct to represent a 16-bit brain floating-point number.
Definition amd_hip_bfloat16.h:47