Using float16 and float8 Formats in ONNX Models

MetaQuotes | 8 April, 2024

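As machine learning and artificial intelligence advance, there is a growing need to optimize how models are processed. The efficiency of a model depends directly on the data formats used to represent it. In recent years, several new data types have appeared that are designed specifically for deep learning models.

In this article, we focus on two such data formats, float16 and float8, which are increasingly used in modern ONNX models. They are alternatives to more precise but more resource-intensive floating-point formats, and they offer a practical trade-off between performance and accuracy, which makes them attractive for a wide range of machine learning tasks. We will look at the key characteristics and advantages of the float16 and float8 formats and introduce the functions that convert them to and from the standard float and double formats.

This should help developers and researchers understand how to use these formats effectively in their own projects and models. As an example, we will examine the operation of the ESRGAN ONNX model, which is used for image quality enhancement.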


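1. New data types used in ONNX models

To speed up computation, some models use lower-precision data types such as Float16 or even Float8.

Support for these new data types has been added to the MQL5 language for working with ONNX models, which makes it possible to handle 8-bit and 16-bit floating-point representations.

The following script prints the complete list of elements of the ENUM_ONNX_DATA_TYPE enumeration.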

//+------------------------------------------------------------------+
//|                                              ONNX_Data_Types.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
void OnStart()
  {
//---
   for(int i=0; i<21; i++)
      PrintFormat("%2d: %s",i,EnumToString(ENUM_ONNX_DATA_TYPE(i)));
  }

Output:

 0: ONNX_DATA_TYPE_UNDEFINED
 1: ONNX_DATA_TYPE_FLOAT
 2: ONNX_DATA_TYPE_UINT8
 3: ONNX_DATA_TYPE_INT8
 4: ONNX_DATA_TYPE_UINT16
 5: ONNX_DATA_TYPE_INT16
 6: ONNX_DATA_TYPE_INT32
 7: ONNX_DATA_TYPE_INT64
 8: ONNX_DATA_TYPE_STRING
 9: ONNX_DATA_TYPE_BOOL
10: ONNX_DATA_TYPE_FLOAT16
11: ONNX_DATA_TYPE_DOUBLE
12: ONNX_DATA_TYPE_UINT32
13: ONNX_DATA_TYPE_UINT64
14: ONNX_DATA_TYPE_COMPLEX64
15: ONNX_DATA_TYPE_COMPLEX128
16: ONNX_DATA_TYPE_BFLOAT16
17: ONNX_DATA_TYPE_FLOAT8E4M3FN
18: ONNX_DATA_TYPE_FLOAT8E4M3FNUZ
19: ONNX_DATA_TYPE_FLOAT8E5M2
20: ONNX_DATA_TYPE_FLOAT8E5M2FNUZ

Thus, it is now possible to run ONNX models that work with such data.

In addition, MQL5 now provides the following functions for data conversion:

bool ArrayToFP16(ushort &dst_array[],const float &src_array[],ENUM_FLOAT16_FORMAT fmt);
bool ArrayToFP16(ushort &dst_array[],const double &src_array[],ENUM_FLOAT16_FORMAT fmt);
bool ArrayToFP8(uchar &dst_array[],const float &src_array[],ENUM_FLOAT8_FORMAT fmt);
bool ArrayToFP8(uchar &dst_array[],const double &src_array[],ENUM_FLOAT8_FORMAT fmt);

bool ArrayFromFP16(float &dst_array[],const ushort &src_array[],ENUM_FLOAT16_FORMAT fmt);
bool ArrayFromFP16(double &dst_array[],const ushort &src_array[],ENUM_FLOAT16_FORMAT fmt);
bool ArrayFromFP8(float &dst_array[],const uchar &src_array[],ENUM_FLOAT8_FORMAT fmt);
bool ArrayFromFP8(double &dst_array[],const uchar &src_array[],ENUM_FLOAT8_FORMAT fmt);

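Since 16-bit and 8-bit floating-point formats come in several variants, the fmt parameter of the conversion functions must specify which number format is being processed.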
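To see why the format has to be named explicitly, note that the same 16 bits stand for different numbers in the two 16-bit layouts. The following Python sketch is an illustration only, not part of the MQL5 API; the helper names fp16_bits_to_float and bf16_bits_to_float are our own. It interprets one bit pattern both ways using the standard struct module.

```python
import struct

def fp16_bits_to_float(bits: int) -> float:
    """Interpret 16 bits as an IEEE 754 half-precision number (FLOAT16)."""
    return struct.unpack("<e", struct.pack("<H", bits))[0]

def bf16_bits_to_float(bits: int) -> float:
    """Interpret 16 bits as BFLOAT16, i.e. as the upper half of a float32."""
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

bits = 0x3C00
print(fp16_bits_to_float(bits))  # 1.0
print(bf16_bits_to_float(bits))  # 0.0078125
```

The pattern 0x3C00 is exactly 1.0 as FLOAT16 but 2^-7 as BFLOAT16, so passing the wrong format to a conversion function silently produces wrong values rather than an error.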

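For the 16-bit versions, the new ENUM_FLOAT16_FORMAT enumeration is used. The examples in this article use its values FLOAT_FP16 and FLOAT_BFP16.

For the 8-bit versions, the new ENUM_FLOAT8_FORMAT enumeration is used.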


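1.1. FP16 formats

FLOAT16 and BFLOAT16 are data types used to represent floating-point numbers.

FLOAT16, also known as the half-precision floating-point format, uses 16 bits to represent a floating-point number: 1 sign bit, 5 exponent bits, and 10 mantissa bits. It offers a balance between precision and computational efficiency. FLOAT16 is widely used in deep learning and neural networks, where high performance is needed when processing large amounts of data. By reducing the size of each number, the format speeds up computation, which is especially important when training deep neural networks on graphics processing units (GPUs).

BFLOAT16 (Brain Floating Point 16) also uses 16 bits but differs from FLOAT16 in how the number is laid out. In this format, 1 bit holds the sign, 8 bits are allocated to the exponent, and the remaining 7 bits to the mantissa. The format was developed for deep learning and artificial intelligence, in particular for Google's Tensor Processing Unit (TPU) processors. BFLOAT16 performs well when training neural networks and can be used effectively to speed up computation.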

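Both formats occupy the same 16 bits, and each has its advantages and limitations. FLOAT16 is more precise thanks to its longer mantissa, but its short exponent covers a much narrower range of values. BFLOAT16 covers roughly the same range as a 32-bit float and converts to and from it cheaply, but it is less precise.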
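The trade-off between range and precision follows directly from the field widths. As a rough sketch (the function name format_limits is our own, and the formula assumes an IEEE-style format in which the all-ones exponent is reserved for infinity and NaN), the largest finite value and the spacing of values just above 1.0 can be computed like this:

```python
def format_limits(exp_bits: int, man_bits: int):
    """Return (largest finite value, spacing of values just above 1.0)
    for an IEEE-style binary format with the given field widths."""
    bias = 2 ** (exp_bits - 1) - 1
    # The all-ones exponent field is reserved, so the largest usable
    # unbiased exponent equals the bias.
    max_exp = bias
    largest = (2 - 2.0 ** -man_bits) * 2.0 ** max_exp
    epsilon = 2.0 ** -man_bits
    return largest, epsilon

for name, e, m in (("FLOAT16", 5, 10), ("BFLOAT16", 8, 7)):
    largest, eps = format_limits(e, m)
    print(f"{name}: max={largest:.6g} epsilon={eps}")
```

With 5 exponent bits FLOAT16 tops out at 65504, while BFLOAT16 with 8 exponent bits reaches the float32 range; in exchange, the step between neighbouring values near 1.0 is 2^-10 for FLOAT16 and only 2^-7 for BFLOAT16.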


Figure 1. Bit representation of FLOAT16 and BFLOAT16 floating-point numbers


Table 1. Floating-point numbers in FLOAT16 format


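1.1.1. Execution tests of the ONNX Cast operator for FLOAT16

As an example, consider the task of converting data of type FLOAT16 to the float and double types.

ONNX models with the Cast operation: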

Figure 2. Input and output parameters of the model test_cast_FLOAT16_to_DOUBLE.onnx


Figure 3. Input and output parameters of the model test_cast_FLOAT16_to_FLOAT.onnx


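As the property descriptions of the ONNX models show, the input requires data of type ONNX_DATA_TYPE_FLOAT16, and the models return output data in ONNX_DATA_TYPE_DOUBLE and ONNX_DATA_TYPE_FLOAT format, respectively.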

To convert the values, we will use the functions ArrayToFP16() and ArrayFromFP16() with the FLOAT_FP16 parameter.

Example:

//+------------------------------------------------------------------+
//|                                              TestCastFloat16.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"

#resource "models\\test_cast_FLOAT16_to_DOUBLE.onnx" as const uchar ExtModel1[];
#resource "models\\test_cast_FLOAT16_to_FLOAT.onnx" as const uchar ExtModel2[];

//+------------------------------------------------------------------+
//| union for data conversion                                        |
//+------------------------------------------------------------------+
template<typename T>
union U
  {
   uchar uc[sizeof(T)];
   T value;
  };
//+------------------------------------------------------------------+
//| ArrayToString                                                    |
//+------------------------------------------------------------------+
template<typename T>
string ArrayToString(const T &data[],uint length=16)
  {
   string res;

   for(uint n=0; n<MathMin(length,data.Size()); n++)
      res+="," + StringFormat("%.2x",data[n]);

   StringSetCharacter(res,0,'[');
   return res+"]";
  }

//+------------------------------------------------------------------+
//| PatchONNXModel                                                   |
//+------------------------------------------------------------------+
void PatchONNXModel(const uchar &original_model[],uchar &patched_model[])
  {
   ArrayCopy(patched_model,original_model,0,0,WHOLE_ARRAY);
//--- special ONNX model patch (IR=9, Opset=20):
//--- in these test models byte 1 holds the IR version and the last byte the opset version
   patched_model[1]=0x09;
   patched_model[ArraySize(patched_model)-1]=0x14;
  }
//+------------------------------------------------------------------+
//| CreateModel                                                      |
//+------------------------------------------------------------------+
bool CreateModel(long &model_handle,const uchar &model[])
  {
   model_handle=INVALID_HANDLE;
   ulong flags=ONNX_DEFAULT;
//ulong flags=ONNX_DEBUG_LOGS;
//---
   model_handle=OnnxCreateFromBuffer(model,flags);
   if(model_handle==INVALID_HANDLE)
      return(false);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| PrepareShapes                                                    |
//+------------------------------------------------------------------+
bool PrepareShapes(long model_handle)
  {
   ulong input_shape1[]= {3,4};
   if(!OnnxSetInputShape(model_handle,0,input_shape1))
     {
      PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError());
      //--
      OnnxRelease(model_handle);
      return(false);
     }
//---
   ulong output_shape[]= {3,4};
   if(!OnnxSetOutputShape(model_handle,0,output_shape))
     {
      PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError());
      //--
      OnnxRelease(model_handle);
      return(false);
     }
//---
   return(true);
  }

//+------------------------------------------------------------------+
//| RunCastFloat16ToDouble                                           |
//+------------------------------------------------------------------+
bool RunCastFloat16ToDouble(long model_handle)
  {
   PrintFormat("test=%s",__FUNCTION__);

   double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12};
   ushort data_uint16[12];
   if(!ArrayToFP16(data_uint16,test_data,FLOAT_FP16))
     {
      Print("error in ArrayToFP16. error code=",GetLastError());
      return(false);
     }
   Print("test array:");
   ArrayPrint(test_data);
   Print("ArrayToFP16:");
   ArrayPrint(data_uint16);

   U<ushort> input_float16_values[3*4];
   U<double> output_double_values[3*4];

   float test_data_float[];
   if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_FP16))
     {
      Print("error in ArrayFromFP16. error code=",GetLastError());
      return(false);
     }

   for(int i=0; i<12; i++)
     {
      input_float16_values[i].value=data_uint16[i];
      PrintFormat("%d input value =%f  Hex float16 = %s  ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value);
     }

   Print("ONNX input array:");
   ArrayPrint(input_float16_values);

   bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_double_values);
   if(!res)
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      return(false);
     }

   Print("ONNX output array:");
   ArrayPrint(output_double_values);
//---
   double sum_error=0.0;
   for(int i=0; i<12; i++)
     {
      double delta=test_data[i]-output_double_values[i].value;
      sum_error+=MathAbs(delta);
      PrintFormat("%d output double %f = %s  difference=%f",i,output_double_values[i].value,ArrayToString(output_double_values[i].uc),delta);
     }
//---
   PrintFormat("test=%s   sum_error=%f",__FUNCTION__,sum_error);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| RunCastFloat16ToFloat                                            |
//+------------------------------------------------------------------+
bool RunCastFloat16ToFloat(long model_handle)
  {
   PrintFormat("test=%s",__FUNCTION__);

   double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12};
   ushort data_uint16[12];
   if(!ArrayToFP16(data_uint16,test_data,FLOAT_FP16))
     {
      Print("error in ArrayToFP16. error code=",GetLastError());
      return(false);
     }
   Print("test array:");
   ArrayPrint(test_data);
   Print("ArrayToFP16:");
   ArrayPrint(data_uint16);

   U<ushort> input_float16_values[3*4];
   U<float>  output_float_values[3*4];

   float test_data_float[];
   if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_FP16))
     {
      Print("error in ArrayFromFP16. error code=",GetLastError());
      return(false);
     }

   for(int i=0; i<12; i++)
     {
      input_float16_values[i].value=data_uint16[i];
      PrintFormat("%d input value =%f  Hex float16 = %s  ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value);
     }

   Print("ONNX input array:");
   ArrayPrint(input_float16_values);

   bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_float_values);
   if(!res)
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      return(false);
     }

   Print("ONNX output array:");
   ArrayPrint(output_float_values);
//---
   double sum_error=0.0;
   for(int i=0; i<12; i++)
     {
      double delta=test_data[i]-(double)output_float_values[i].value;
      sum_error+=MathAbs(delta);
      PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToString(output_float_values[i].uc),delta);
     }
//---
   PrintFormat("test=%s   sum_error=%f",__FUNCTION__,sum_error);
//---
   return(true);
  }

//+------------------------------------------------------------------+
//| TestCastFloat16ToFloat                                           |
//+------------------------------------------------------------------+
bool TestCastFloat16ToFloat(const uchar &res_model[])
  {
   uchar model[];
   PatchONNXModel(res_model,model);
//--- create the model and get its handle
   long model_handle=INVALID_HANDLE;
   if(!CreateModel(model_handle,model))
      return(false);
//--- prepare input and output shapes (the handle is released on failure)
   if(!PrepareShapes(model_handle))
      return(false);
//--- run ONNX model, releasing the handle if the run fails
   if(!RunCastFloat16ToFloat(model_handle))
     {
      OnnxRelease(model_handle);
      return(false);
     }
//--- release model handle
   OnnxRelease(model_handle);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| TestCastFloat16ToDouble                                          |
//+------------------------------------------------------------------+
bool TestCastFloat16ToDouble(const uchar &res_model[])
  {
   uchar model[];
   PatchONNXModel(res_model,model);
//--- create the model and get its handle
   long model_handle=INVALID_HANDLE;
   if(!CreateModel(model_handle,model))
      return(false);
//--- prepare input and output shapes (the handle is released on failure)
   if(!PrepareShapes(model_handle))
      return(false);
//--- run ONNX model, releasing the handle if the run fails
   if(!RunCastFloat16ToDouble(model_handle))
     {
      OnnxRelease(model_handle);
      return(false);
     }
//--- release model handle
   OnnxRelease(model_handle);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
int OnStart(void)
  {
   if(!TestCastFloat16ToDouble(ExtModel1))
      return 1;

   if(!TestCastFloat16ToFloat(ExtModel2))
      return 1;
//---
   return 0;
  }
//+------------------------------------------------------------------+

Output:

TestCastFloat16 (EURUSD,H1)     test=RunCastFloat16ToDouble
TestCastFloat16 (EURUSD,H1)     test array:
TestCastFloat16 (EURUSD,H1)      1.00000  2.00000  3.00000  4.00000  5.00000  6.00000  7.00000  8.00000  9.00000 10.00000 11.00000 12.00000
TestCastFloat16 (EURUSD,H1)     ArrayToFP16:
TestCastFloat16 (EURUSD,H1)     15360 16384 16896 17408 17664 17920 18176 18432 18560 18688 18816 18944
TestCastFloat16 (EURUSD,H1)     0 input value =1.000000  Hex float16 = [00,3c]  ushort value=15360
TestCastFloat16 (EURUSD,H1)     1 input value =2.000000  Hex float16 = [00,40]  ushort value=16384
TestCastFloat16 (EURUSD,H1)     2 input value =3.000000  Hex float16 = [00,42]  ushort value=16896
TestCastFloat16 (EURUSD,H1)     3 input value =4.000000  Hex float16 = [00,44]  ushort value=17408
TestCastFloat16 (EURUSD,H1)     4 input value =5.000000  Hex float16 = [00,45]  ushort value=17664
TestCastFloat16 (EURUSD,H1)     5 input value =6.000000  Hex float16 = [00,46]  ushort value=17920
TestCastFloat16 (EURUSD,H1)     6 input value =7.000000  Hex float16 = [00,47]  ushort value=18176
TestCastFloat16 (EURUSD,H1)     7 input value =8.000000  Hex float16 = [00,48]  ushort value=18432
TestCastFloat16 (EURUSD,H1)     8 input value =9.000000  Hex float16 = [80,48]  ushort value=18560
TestCastFloat16 (EURUSD,H1)     9 input value =10.000000  Hex float16 = [00,49]  ushort value=18688
TestCastFloat16 (EURUSD,H1)     10 input value =11.000000  Hex float16 = [80,49]  ushort value=18816
TestCastFloat16 (EURUSD,H1)     11 input value =12.000000  Hex float16 = [00,4a]  ushort value=18944
TestCastFloat16 (EURUSD,H1)     ONNX input array:
TestCastFloat16 (EURUSD,H1)          [uc] [value]
TestCastFloat16 (EURUSD,H1)     [ 0]  ...   15360
TestCastFloat16 (EURUSD,H1)     [ 1]  ...   16384
TestCastFloat16 (EURUSD,H1)     [ 2]  ...   16896
TestCastFloat16 (EURUSD,H1)     [ 3]  ...   17408
TestCastFloat16 (EURUSD,H1)     [ 4]  ...   17664
TestCastFloat16 (EURUSD,H1)     [ 5]  ...   17920
TestCastFloat16 (EURUSD,H1)     [ 6]  ...   18176
TestCastFloat16 (EURUSD,H1)     [ 7]  ...   18432
TestCastFloat16 (EURUSD,H1)     [ 8]  ...   18560
TestCastFloat16 (EURUSD,H1)     [ 9]  ...   18688
TestCastFloat16 (EURUSD,H1)     [10]  ...   18816
TestCastFloat16 (EURUSD,H1)     [11]  ...   18944
TestCastFloat16 (EURUSD,H1)     ONNX output array:
TestCastFloat16 (EURUSD,H1)          [uc]  [value]
TestCastFloat16 (EURUSD,H1)     [ 0]  ...  1.00000
TestCastFloat16 (EURUSD,H1)     [ 1]  ...  2.00000
TestCastFloat16 (EURUSD,H1)     [ 2]  ...  3.00000
TestCastFloat16 (EURUSD,H1)     [ 3]  ...  4.00000
TestCastFloat16 (EURUSD,H1)     [ 4]  ...  5.00000
TestCastFloat16 (EURUSD,H1)     [ 5]  ...  6.00000
TestCastFloat16 (EURUSD,H1)     [ 6]  ...  7.00000
TestCastFloat16 (EURUSD,H1)     [ 7]  ...  8.00000
TestCastFloat16 (EURUSD,H1)     [ 8]  ...  9.00000
TestCastFloat16 (EURUSD,H1)     [ 9]  ... 10.00000
TestCastFloat16 (EURUSD,H1)     [10]  ... 11.00000
TestCastFloat16 (EURUSD,H1)     [11]  ... 12.00000
TestCastFloat16 (EURUSD,H1)     0 output double 1.000000 = [00,00,00,00,00,00,f0,3f]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     1 output double 2.000000 = [00,00,00,00,00,00,00,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     2 output double 3.000000 = [00,00,00,00,00,00,08,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     3 output double 4.000000 = [00,00,00,00,00,00,10,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     4 output double 5.000000 = [00,00,00,00,00,00,14,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     5 output double 6.000000 = [00,00,00,00,00,00,18,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     6 output double 7.000000 = [00,00,00,00,00,00,1c,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     7 output double 8.000000 = [00,00,00,00,00,00,20,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     8 output double 9.000000 = [00,00,00,00,00,00,22,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     9 output double 10.000000 = [00,00,00,00,00,00,24,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     10 output double 11.000000 = [00,00,00,00,00,00,26,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     11 output double 12.000000 = [00,00,00,00,00,00,28,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     test=RunCastFloat16ToDouble   sum_error=0.000000
TestCastFloat16 (EURUSD,H1)     test=RunCastFloat16ToFloat
TestCastFloat16 (EURUSD,H1)     test array:
TestCastFloat16 (EURUSD,H1)      1.00000  2.00000  3.00000  4.00000  5.00000  6.00000  7.00000  8.00000  9.00000 10.00000 11.00000 12.00000
TestCastFloat16 (EURUSD,H1)     ArrayToFP16:
TestCastFloat16 (EURUSD,H1)     15360 16384 16896 17408 17664 17920 18176 18432 18560 18688 18816 18944
TestCastFloat16 (EURUSD,H1)     0 input value =1.000000  Hex float16 = [00,3c]  ushort value=15360
TestCastFloat16 (EURUSD,H1)     1 input value =2.000000  Hex float16 = [00,40]  ushort value=16384
TestCastFloat16 (EURUSD,H1)     2 input value =3.000000  Hex float16 = [00,42]  ushort value=16896
TestCastFloat16 (EURUSD,H1)     3 input value =4.000000  Hex float16 = [00,44]  ushort value=17408
TestCastFloat16 (EURUSD,H1)     4 input value =5.000000  Hex float16 = [00,45]  ushort value=17664
TestCastFloat16 (EURUSD,H1)     5 input value =6.000000  Hex float16 = [00,46]  ushort value=17920
TestCastFloat16 (EURUSD,H1)     6 input value =7.000000  Hex float16 = [00,47]  ushort value=18176
TestCastFloat16 (EURUSD,H1)     7 input value =8.000000  Hex float16 = [00,48]  ushort value=18432
TestCastFloat16 (EURUSD,H1)     8 input value =9.000000  Hex float16 = [80,48]  ushort value=18560
TestCastFloat16 (EURUSD,H1)     9 input value =10.000000  Hex float16 = [00,49]  ushort value=18688
TestCastFloat16 (EURUSD,H1)     10 input value =11.000000  Hex float16 = [80,49]  ushort value=18816
TestCastFloat16 (EURUSD,H1)     11 input value =12.000000  Hex float16 = [00,4a]  ushort value=18944
TestCastFloat16 (EURUSD,H1)     ONNX input array:
TestCastFloat16 (EURUSD,H1)          [uc] [value]
TestCastFloat16 (EURUSD,H1)     [ 0]  ...   15360
TestCastFloat16 (EURUSD,H1)     [ 1]  ...   16384
TestCastFloat16 (EURUSD,H1)     [ 2]  ...   16896
TestCastFloat16 (EURUSD,H1)     [ 3]  ...   17408
TestCastFloat16 (EURUSD,H1)     [ 4]  ...   17664
TestCastFloat16 (EURUSD,H1)     [ 5]  ...   17920
TestCastFloat16 (EURUSD,H1)     [ 6]  ...   18176
TestCastFloat16 (EURUSD,H1)     [ 7]  ...   18432
TestCastFloat16 (EURUSD,H1)     [ 8]  ...   18560
TestCastFloat16 (EURUSD,H1)     [ 9]  ...   18688
TestCastFloat16 (EURUSD,H1)     [10]  ...   18816
TestCastFloat16 (EURUSD,H1)     [11]  ...   18944
TestCastFloat16 (EURUSD,H1)     ONNX output array:
TestCastFloat16 (EURUSD,H1)          [uc]  [value]
TestCastFloat16 (EURUSD,H1)     [ 0]  ...  1.00000
TestCastFloat16 (EURUSD,H1)     [ 1]  ...  2.00000
TestCastFloat16 (EURUSD,H1)     [ 2]  ...  3.00000
TestCastFloat16 (EURUSD,H1)     [ 3]  ...  4.00000
TestCastFloat16 (EURUSD,H1)     [ 4]  ...  5.00000
TestCastFloat16 (EURUSD,H1)     [ 5]  ...  6.00000
TestCastFloat16 (EURUSD,H1)     [ 6]  ...  7.00000
TestCastFloat16 (EURUSD,H1)     [ 7]  ...  8.00000
TestCastFloat16 (EURUSD,H1)     [ 8]  ...  9.00000
TestCastFloat16 (EURUSD,H1)     [ 9]  ... 10.00000
TestCastFloat16 (EURUSD,H1)     [10]  ... 11.00000
TestCastFloat16 (EURUSD,H1)     [11]  ... 12.00000
TestCastFloat16 (EURUSD,H1)     0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastFloat16 (EURUSD,H1)     1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     8 output float 9.000000 = [00,00,10,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     10 output float 11.000000 = [00,00,30,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     test=RunCastFloat16ToFloat   sum_error=0.000000
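The log can be cross-checked outside the terminal. The sketch below is plain Python rather than MQL5, and fp16_from_ushort is a helper name of our own. It takes the ushort values printed after "ArrayToFP16:", shows that the "Hex float16" column is simply their little-endian byte order, and reinterprets the same bytes as half-precision numbers.

```python
import struct

# ushort values copied from the "ArrayToFP16:" line of the log above
logged = [15360, 16384, 16896, 17408, 17664, 17920,
          18176, 18432, 18560, 18688, 18816, 18944]

def fp16_from_ushort(value: int) -> float:
    raw = struct.pack("<H", value)       # little-endian bytes, e.g. 15360 -> 00 3c
    return struct.unpack("<e", raw)[0]   # reinterpret the same bytes as FLOAT16

first_bytes = [f"{b:02x}" for b in struct.pack("<H", logged[0])]
decoded = [fp16_from_ushort(v) for v in logged]
print(first_bytes)
print(decoded)
```

All twelve test values are small integers that FLOAT16 represents exactly, which is why the script reports sum_error=0.000000 for both models.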


1.1.2. Execution tests of the ONNX Cast operator for BFLOAT16

This example examines conversion from BFLOAT16 to float.

ONNX model with the Cast operation:


Figure 4. Input and output parameters of the model test_cast_BFLOAT16_to_FLOAT.onnx

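Input data of type ONNX_DATA_TYPE_BFLOAT16 is required, and the model returns output data in ONNX_DATA_TYPE_FLOAT format.

To convert the values, we will use the functions ArrayToFP16() and ArrayFromFP16() with the FLOAT_BFP16 parameter.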

//+------------------------------------------------------------------+
//|                                             TestCastBFloat16.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"

#resource "models\\test_cast_BFLOAT16_to_FLOAT.onnx" as const uchar ExtModel1[];

//+------------------------------------------------------------------+
//| union for data conversion                                        |
//+------------------------------------------------------------------+
template<typename T>
union U
  {
   uchar uc[sizeof(T)];
   T value;
  };
//+------------------------------------------------------------------+
//| ArrayToString                                                    |
//+------------------------------------------------------------------+
template<typename T>
string ArrayToString(const T &data[],uint length=16)
  {
   string res;

   for(uint n=0; n<MathMin(length,data.Size()); n++)
      res+="," + StringFormat("%.2x",data[n]);

   StringSetCharacter(res,0,'[');
   return res+"]";
  }

//+------------------------------------------------------------------+
//| PatchONNXModel                                                   |
//+------------------------------------------------------------------+
void PatchONNXModel(const uchar &original_model[],uchar &patched_model[])
  {
   ArrayCopy(patched_model,original_model,0,0,WHOLE_ARRAY);
//--- special ONNX model patch (IR=9, Opset=20):
//--- in this test model byte 1 holds the IR version and the last byte the opset version
   patched_model[1]=0x09;
   patched_model[ArraySize(patched_model)-1]=0x14;
  }
//+------------------------------------------------------------------+
//| CreateModel                                                      |
//+------------------------------------------------------------------+
bool CreateModel(long &model_handle,const uchar &model[])
  {
   model_handle=INVALID_HANDLE;
   ulong flags=ONNX_DEFAULT;
//ulong flags=ONNX_DEBUG_LOGS;
//---
   model_handle=OnnxCreateFromBuffer(model,flags);
   if(model_handle==INVALID_HANDLE)
      return(false);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| PrepareShapes                                                    |
//+------------------------------------------------------------------+
bool PrepareShapes(long model_handle)
  {
   ulong input_shape1[]= {3,4};
   if(!OnnxSetInputShape(model_handle,0,input_shape1))
     {
      PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError());
      //--
      OnnxRelease(model_handle);
      return(false);
     }
//---
   ulong output_shape[]= {3,4};
   if(!OnnxSetOutputShape(model_handle,0,output_shape))
     {
      PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError());
      //--
      OnnxRelease(model_handle);
      return(false);
     }
//---
   return(true);
  }

//+------------------------------------------------------------------+
//| RunCastBFloat16ToFloat                                           |
//+------------------------------------------------------------------+
bool RunCastBFloat16ToFloat(long model_handle)
  {
   PrintFormat("test=%s",__FUNCTION__);

   double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12};
   ushort data_uint16[12];
   if(!ArrayToFP16(data_uint16,test_data,FLOAT_BFP16))
     {
      Print("error in ArrayToFP16. error code=",GetLastError());
      return(false);
     }
   Print("test array:");
   ArrayPrint(test_data);
   Print("ArrayToFP16:");
   ArrayPrint(data_uint16);

   U<ushort> input_float16_values[3*4];
   U<float>  output_float_values[3*4];

   float test_data_float[];
   if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_BFP16))
     {
      Print("error in ArrayFromFP16. error code=",GetLastError());
      return(false);
     }

   for(int i=0; i<12; i++)
     {
      input_float16_values[i].value=data_uint16[i];
      PrintFormat("%d input value =%f  Hex float16 = %s  ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value);
     }

   Print("ONNX input array:");
   ArrayPrint(input_float16_values);

   bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_float_values);
   if(!res)
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      return(false);
     }

   Print("ONNX output array:");
   ArrayPrint(output_float_values);
//---
   double sum_error=0.0;
   for(int i=0; i<12; i++)
     {
      double delta=test_data[i]-(double)output_float_values[i].value;
      sum_error+=MathAbs(delta);
      PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToString(output_float_values[i].uc),delta);
     }
//---
   PrintFormat("test=%s   sum_error=%f",__FUNCTION__,sum_error);
//---
   return(true);
  }

//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
int OnStart(void)
  {
   uchar model[];
   PatchONNXModel(ExtModel1,model);
//--- create the model and get its handle
   long model_handle=INVALID_HANDLE;
   if(!CreateModel(model_handle,model))
      return 1;
//--- prepare input and output shapes (the handle is released on failure)
   if(!PrepareShapes(model_handle))
      return 1;
//--- run ONNX model, releasing the handle if the run fails
   if(!RunCastBFloat16ToFloat(model_handle))
     {
      OnnxRelease(model_handle);
      return 1;
     }
//--- release model handle
   OnnxRelease(model_handle);
//---
   return 0;
  }
//+------------------------------------------------------------------+
Output:
TestCastBFloat16 (EURUSD,H1)    test=RunCastBFloat16ToFloat
TestCastBFloat16 (EURUSD,H1)    test array:
TestCastBFloat16 (EURUSD,H1)     1.00000  2.00000  3.00000  4.00000  5.00000  6.00000  7.00000  8.00000  9.00000 10.00000 11.00000 12.00000
TestCastBFloat16 (EURUSD,H1)    ArrayToFP16:
TestCastBFloat16 (EURUSD,H1)    16256 16384 16448 16512 16544 16576 16608 16640 16656 16672 16688 16704
TestCastBFloat16 (EURUSD,H1)    0 input value =1.000000  Hex float16 = [80,3f]  ushort value=16256
TestCastBFloat16 (EURUSD,H1)    1 input value =2.000000  Hex float16 = [00,40]  ushort value=16384
TestCastBFloat16 (EURUSD,H1)    2 input value =3.000000  Hex float16 = [40,40]  ushort value=16448
TestCastBFloat16 (EURUSD,H1)    3 input value =4.000000  Hex float16 = [80,40]  ushort value=16512
TestCastBFloat16 (EURUSD,H1)    4 input value =5.000000  Hex float16 = [a0,40]  ushort value=16544
TestCastBFloat16 (EURUSD,H1)    5 input value =6.000000  Hex float16 = [c0,40]  ushort value=16576
TestCastBFloat16 (EURUSD,H1)    6 input value =7.000000  Hex float16 = [e0,40]  ushort value=16608
TestCastBFloat16 (EURUSD,H1)    7 input value =8.000000  Hex float16 = [00,41]  ushort value=16640
TestCastBFloat16 (EURUSD,H1)    8 input value =9.000000  Hex float16 = [10,41]  ushort value=16656
TestCastBFloat16 (EURUSD,H1)    9 input value =10.000000  Hex float16 = [20,41]  ushort value=16672
TestCastBFloat16 (EURUSD,H1)    10 input value =11.000000  Hex float16 = [30,41]  ushort value=16688
TestCastBFloat16 (EURUSD,H1)    11 input value =12.000000  Hex float16 = [40,41]  ushort value=16704
TestCastBFloat16 (EURUSD,H1)    ONNX input array:
TestCastBFloat16 (EURUSD,H1)         [uc] [value]
TestCastBFloat16 (EURUSD,H1)    [ 0]  ...   16256
TestCastBFloat16 (EURUSD,H1)    [ 1]  ...   16384
TestCastBFloat16 (EURUSD,H1)    [ 2]  ...   16448
TestCastBFloat16 (EURUSD,H1)    [ 3]  ...   16512
TestCastBFloat16 (EURUSD,H1)    [ 4]  ...   16544
TestCastBFloat16 (EURUSD,H1)    [ 5]  ...   16576
TestCastBFloat16 (EURUSD,H1)    [ 6]  ...   16608
TestCastBFloat16 (EURUSD,H1)    [ 7]  ...   16640
TestCastBFloat16 (EURUSD,H1)    [ 8]  ...   16656
TestCastBFloat16 (EURUSD,H1)    [ 9]  ...   16672
TestCastBFloat16 (EURUSD,H1)    [10]  ...   16688
TestCastBFloat16 (EURUSD,H1)    [11]  ...   16704
TestCastBFloat16 (EURUSD,H1)    ONNX output array:
TestCastBFloat16 (EURUSD,H1)         [uc]  [value]
TestCastBFloat16 (EURUSD,H1)    [ 0]  ...  1.00000
TestCastBFloat16 (EURUSD,H1)    [ 1]  ...  2.00000
TestCastBFloat16 (EURUSD,H1)    [ 2]  ...  3.00000
TestCastBFloat16 (EURUSD,H1)    [ 3]  ...  4.00000
TestCastBFloat16 (EURUSD,H1)    [ 4]  ...  5.00000
TestCastBFloat16 (EURUSD,H1)    [ 5]  ...  6.00000
TestCastBFloat16 (EURUSD,H1)    [ 6]  ...  7.00000
TestCastBFloat16 (EURUSD,H1)    [ 7]  ...  8.00000
TestCastBFloat16 (EURUSD,H1)    [ 8]  ...  9.00000
TestCastBFloat16 (EURUSD,H1)    [ 9]  ... 10.00000
TestCastBFloat16 (EURUSD,H1)    [10]  ... 11.00000
TestCastBFloat16 (EURUSD,H1)    [11]  ... 12.00000
TestCastBFloat16 (EURUSD,H1)    0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    8 output float 9.000000 = [00,00,10,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    10 output float 11.000000 = [00,00,30,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    test=RunCastBFloat16ToFloat   sum_error=0.000000
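The log also shows how closely BFLOAT16 is related to float: the two input bytes of each value reappear as the upper two bytes of the float output. A short Python sketch (not MQL5; the helper name bf16_to_float32_bytes is ours) reproduces this by appending 16 zero bits to each logged ushort value.

```python
import struct

# ushort values copied from the "ArrayToFP16:" line of the log above
logged = [16256, 16384, 16448, 16512, 16544, 16576,
          16608, 16640, 16656, 16672, 16688, 16704]

def bf16_to_float32_bytes(value: int) -> bytes:
    """Widen BFLOAT16 to float32 by appending 16 zero mantissa bits."""
    return struct.pack("<I", value << 16)

decoded = []
for v in logged:
    raw = bf16_to_float32_bytes(v)
    decoded.append(struct.unpack("<f", raw)[0])
    print(",".join(f"{b:02x}" for b in raw), decoded[-1])
```

For the first value this prints the bytes 00,00,80,3f and the number 1.0, matching the "output float" line of the log. The widening step loses nothing, so the Cast from BFLOAT16 to float is exact for every input.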


1.2. FP8 formats

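Modern language models can contain billions of parameters. Training models with FP16 numbers has already proven effective. Moving from 16-bit floating-point numbers to FP8 halves the memory requirements and speeds up both training and model execution.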
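The memory saving is simple arithmetic. The sketch below is in Python, and the parameter count is a hypothetical figure chosen only for illustration. It estimates how much storage the weights alone need at each width.

```python
BYTES_PER_VALUE = {"float32": 4, "float16": 2, "float8": 1}

def weights_size_gib(num_params: int, dtype: str) -> float:
    """Storage needed for the weights alone, in GiB."""
    return num_params * BYTES_PER_VALUE[dtype] / 2**30

params = 7_000_000_000  # hypothetical model size, for illustration only
for dtype in BYTES_PER_VALUE:
    print(f"{dtype}: {weights_size_gib(params, dtype):.1f} GiB")
```

Each step down in width halves the figure. Activations, optimizer state, and any tensors kept in higher precision are not counted here, so real savings depend on the model and runtime.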

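FP8 (8-bit floating point) is one of the data types used to represent floating-point numbers. In FP8, each number is represented by 8 bits, usually divided into three components: sign, exponent, and mantissa. The format is a compromise between precision and storage efficiency, which makes it attractive for applications that need to save memory and computing resources.

A key advantage of FP8 is its efficiency in processing large amounts of data. Thanks to its compact number representation, FP8 reduces memory requirements and speeds up computation. This is particularly important in machine learning and artificial intelligence applications that work with large datasets.

In addition, FP8 can be used to implement low-level operations such as arithmetic calculations and signal processing. Its compact format makes it suitable for embedded systems and applications with limited resources. However, FP8 has limitations because of its low precision. In applications that require high-precision calculations, such as scientific computing or financial analysis, FP8 may not be sufficient.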


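1.2.1. FP8 formats fp8_e5m2 and fp8_e4m3

In 2022, two papers were published that introduced floating-point numbers stored in a single byte, in contrast to float32 numbers, which are stored in 4 bytes.

The paper "FP8 Formats for Deep Learning" (2022) by NVIDIA, Intel, and ARM introduces two types that follow the IEEE specification. The first type is E4M3, with 1 bit for the sign, 4 bits for the exponent, and 3 bits for the mantissa. The second type is E5M2, with 1 bit for the sign, 5 bits for the exponent, and 2 bits for the mantissa. The first type is generally used for weights, the second for gradients.

The second paper, "8-bit Numerical Formats for Deep Neural Networks", introduces similar types. The IEEE standard gives the same value to +0 (or integer 0) and -0 (or integer 128). The paper proposes giving distinct floating-point values to these two numbers. It also explores different splits between exponent and mantissa and shows that E4M3 and E5M2 are the best ones.

As a result, ONNX introduced 4 new types (starting with version 1.15.0):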

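  • E4M3FN: 1 bit for the sign, 4 bits for the exponent, 3 bits for the mantissa; only NaN values and no infinite values (FN).
  • E4M3FNUZ: 1 bit for the sign, 4 bits for the exponent, 3 bits for the mantissa; only NaN values and no infinite values (FN), no negative zero (UZ).
  • E5M2: 1 bit for the sign, 5 bits for the exponent, 2 bits for the mantissa.
  • E5M2FNUZ: 1 bit for the sign, 5 bits for the exponent, 2 bits for the mantissa; only NaN values and no infinite values (FN), no negative zero (UZ).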
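
To make the four bit layouts concrete, here is a minimal Python sketch of a decoder that turns one FP8 byte into a Python float. It is not the MQL5 ArrayFromFP8 function and not ONNX code; the function name decode_fp8 and the table FP8_FORMATS are hypothetical. The exponent biases (7, 8, 15, 16) and the special-value rules are written from the ONNX float8 description as we understand it and should be treated as assumptions.

```python
import math

# (exponent bits, mantissa bits, exponent bias) for each ONNX float8 type (assumed values)
FP8_FORMATS = {
    "E4M3FN":   (4, 3, 7),
    "E4M3FNUZ": (4, 3, 8),
    "E5M2":     (5, 2, 15),
    "E5M2FNUZ": (5, 2, 16),
}

def decode_fp8(byte, fmt):
    """Decode one FP8 byte (0..255) into a Python float."""
    exp_bits, man_bits, bias = FP8_FORMATS[fmt]
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> man_bits) & ((1 << exp_bits) - 1)
    mantissa = byte & ((1 << man_bits) - 1)

    if fmt.endswith("FNUZ"):
        if byte == 0x80:                      # no negative zero: this code is the only NaN
            return math.nan
    elif fmt == "E4M3FN":
        if (byte & 0x7F) == 0x7F:             # all exponent and mantissa bits set: NaN, no infinity
            return math.nan
    elif exponent == (1 << exp_bits) - 1:     # E5M2 keeps IEEE-style infinity and NaN
        return sign * math.inf if mantissa == 0 else math.nan

    if exponent == 0:                         # zero and subnormal numbers
        return sign * mantissa * 2.0 ** (1 - bias - man_bits)
    return sign * (1 + mantissa / (1 << man_bits)) * 2.0 ** (exponent - bias)

for fmt, code in (("E4M3FN", 0x38), ("E4M3FNUZ", 0x40), ("E5M2", 0x3C), ("E5M2FNUZ", 0x40)):
    print(fmt, hex(code), decode_fp8(code, fmt))
```

Each of the four calls decodes to 1.0, which agrees with the hex codes that the Cast test script in section 1.2.2 prints for the input value 1.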

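Implementations usually depend on the hardware. NVIDIA, Intel, and Arm implement E4M3FN and E5M2 in their latest graphics processors (GPUs). GraphCore does the same, but with E4M3FNUZ and E5M2FNUZ.

Let us briefly summarize the main information about the FP8 types, based on the article "NVIDIA Hopper: H100 and FP8 Support".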


Figure 5. Bit representation of FP8 formats


Table 3. Floating-point numbers in the E5M2 format


Table 4. Floating-point numbers in the E4M3 format

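A comparison of the positive value ranges of FP8_E4M3 and FP8_E5M2 numbers is shown in Figure 6.

Figure 6. Comparison of the positive value ranges of FP8 numbers


A comparison of the accuracy of arithmetic operations (Add, Mul, Div) on numbers in the FP8_E5M2 and FP8_E4M3 formats is shown in Figure 7.

Figure 7. Comparison of the accuracy of arithmetic operations on numbers in the float8_e5m2 and float8_e4m3 formats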

Recommended usage of FP8 format numbers:


1.2.2. Execution tests of the ONNX Cast operator for FLOAT8

This example examines conversion from the various FLOAT8 types to float.

ONNX models with the Cast operation:


Figure 8. Input and output parameters of the model test_cast_FLOAT8E4M3FN_to_FLOAT.onnx in MetaEditor


Figure 9. Input and output parameters of the model test_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnx in MetaEditor


Figure 10. Input and output parameters of the model test_cast_FLOAT8E5M2_to_FLOAT.onnx in MetaEditor


Figure 11. Input and output parameters of the model test_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnx in MetaEditor


Example:

//+------------------------------------------------------------------+
//|                                               TestCastFloat8.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"

#resource "models\\test_cast_FLOAT8E4M3FN_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E4M3FN_to_FLOAT[];
#resource "models\\test_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E4M3FNUZ_to_FLOAT[];
#resource "models\\test_cast_FLOAT8E5M2_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E5M2_to_FLOAT[];
#resource "models\\test_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E5M2FNUZ_to_FLOAT[];

#define TEST_PASSED 0
#define TEST_FAILED 1
//+------------------------------------------------------------------+
//| union for data conversion                                        |
//+------------------------------------------------------------------+
template<typename T>
union U
  {
   uchar uc[sizeof(T)];
   T value;
  };
//+------------------------------------------------------------------+
//| ArrayToHexString                                                 |
//+------------------------------------------------------------------+
template<typename T>
string ArrayToHexString(const T &data[],uint length=16)
  {
   string res;

   for(uint n=0; n<MathMin(length,data.Size()); n++)
      res+="," + StringFormat("%.2x",data[n]);

   StringSetCharacter(res,0,'[');
   return(res+"]");
  }
//+------------------------------------------------------------------+
//| ArrayToString                                                    |
//+------------------------------------------------------------------+
template<typename T>
string ArrayToString(const U<T> &data[],uint length=16)
  {
   string res;

   for(uint n=0; n<MathMin(length,data.Size()); n++)
      res+="," + (string)data[n].value;

   StringSetCharacter(res,0,'[');
   return(res+"]");
  }
//+------------------------------------------------------------------+
//| CreatePatchedModel                                               |
//+------------------------------------------------------------------+
long CreatePatchedModel(const uchar &original_model[])
  {
   uchar patched_model[];
   ArrayCopy(patched_model,original_model);
//--- special ONNX model patch(IR=9,Opset=20)
   patched_model[1]=0x09;
   patched_model[ArraySize(patched_model)-1]=0x14;

   return(OnnxCreateFromBuffer(patched_model,ONNX_DEFAULT));
  }
//+------------------------------------------------------------------+
//| PrepareShapes                                                    |
//+------------------------------------------------------------------+
bool PrepareShapes(long model_handle)
  {
//--- configure input shape
   ulong input_shape[]= {3,5};

   if(!OnnxSetInputShape(model_handle,0,input_shape))
     {
      PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(false);
     }
//--- configure output shape
   ulong output_shape[]= {3,5};

   if(!OnnxSetOutputShape(model_handle,0,output_shape))
     {
      PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(false);
     }

   return(true);
  }
//+------------------------------------------------------------------+
//| RunCastFloat8ToFloat                                             |
//+------------------------------------------------------------------+
bool RunCastFloat8ToFloat(long model_handle,const ENUM_FLOAT8_FORMAT fmt)
  {
   PrintFormat("TEST: %s(%s)",__FUNCTION__,EnumToString(fmt));
//---
   float test_data[15]   = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15};
   uchar data_float8[15] = {};

   if(!ArrayToFP8(data_float8,test_data,fmt))
     {
      Print("error in ArrayToFP8. error code=",GetLastError());
      OnnxRelease(model_handle);
      return(false);
     }

   U<uchar> input_float8_values[3*5];
   U<float> output_float_values[3*5];
   float    test_data_float[];
//--- convert float8 to float
   if(!ArrayFromFP8(test_data_float,data_float8,fmt))
     {
      Print("error in ArrayFromFP8. error code=",GetLastError());
      OnnxRelease(model_handle);
      return(false);
     }

   for(uint i=0; i<data_float8.Size(); i++)
     {
      input_float8_values[i].value=data_float8[i];
      PrintFormat("%d input value =%f  Hex float8 = %s  ushort value=%d",i,test_data_float[i],ArrayToHexString(input_float8_values[i].uc),input_float8_values[i].value);
     }

   Print("ONNX input array: ",ArrayToString(input_float8_values));
//--- execute model (convert float8 to float using ONNX)
   if(!OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float8_values,output_float_values))
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(false);
     }

   Print("ONNX output array: ",ArrayToString(output_float_values));
//--- calculate error (compare ONNX and ArrayFromFP8 results)
   double sum_error=0.0;

   for(uint i=0; i<test_data.Size(); i++)
     {
      double delta=test_data_float[i]-(double)output_float_values[i].value;
      sum_error+=MathAbs(delta);
      PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToHexString(output_float_values[i].uc),delta);
     }
//---
   PrintFormat("%s(%s): sum_error=%f\n",__FUNCTION__,EnumToString(fmt),sum_error);
   return(true);
  }
//+------------------------------------------------------------------+
//| TestModel                                                        |
//+------------------------------------------------------------------+
bool TestModel(const uchar &model[],const ENUM_FLOAT8_FORMAT fmt)
  {
//--- create patched model
   long model_handle=CreatePatchedModel(model);

   if(model_handle==INVALID_HANDLE)
      return(false);
//--- prepare input and output shapes
   if(!PrepareShapes(model_handle))
      return(false);
//--- run ONNX model
   if(!RunCastFloat8ToFloat(model_handle,fmt))
      return(false);
//--- release model handle
   OnnxRelease(model_handle);

   return(true);
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
int OnStart(void)
  {
//--- run ONNX model
   if(!TestModel(ExtModel_FLOAT8E4M3FN_to_FLOAT,FLOAT_FP8_E4M3FN))
      return(TEST_FAILED);

//--- run ONNX model
   if(!TestModel(ExtModel_FLOAT8E4M3FNUZ_to_FLOAT,FLOAT_FP8_E4M3FNUZ))
      return(TEST_FAILED);

//--- run ONNX model
   if(!TestModel(ExtModel_FLOAT8E5M2_to_FLOAT,FLOAT_FP8_E5M2FN))
      return(TEST_FAILED);

//--- run ONNX model
   if(!TestModel(ExtModel_FLOAT8E5M2FNUZ_to_FLOAT,FLOAT_FP8_E5M2FNUZ))
      return(TEST_FAILED);

   return(TEST_PASSED);
  }
//+------------------------------------------------------------------+
Output:
TestCastFloat8 (EURUSD,H1)      TEST: RunCastFloat8ToFloat(FLOAT_FP8_E4M3FN)
TestCastFloat8 (EURUSD,H1)      0 input value =1.000000  Hex float8 = [38]  ushort value=56
TestCastFloat8 (EURUSD,H1)      1 input value =2.000000  Hex float8 = [40]  ushort value=64
TestCastFloat8 (EURUSD,H1)      2 input value =3.000000  Hex float8 = [44]  ushort value=68
TestCastFloat8 (EURUSD,H1)      3 input value =4.000000  Hex float8 = [48]  ushort value=72
TestCastFloat8 (EURUSD,H1)      4 input value =5.000000  Hex float8 = [4a]  ushort value=74
TestCastFloat8 (EURUSD,H1)      5 input value =6.000000  Hex float8 = [4c]  ushort value=76
TestCastFloat8 (EURUSD,H1)      6 input value =7.000000  Hex float8 = [4e]  ushort value=78
TestCastFloat8 (EURUSD,H1)      7 input value =8.000000  Hex float8 = [50]  ushort value=80
TestCastFloat8 (EURUSD,H1)      8 input value =9.000000  Hex float8 = [51]  ushort value=81
TestCastFloat8 (EURUSD,H1)      9 input value =10.000000  Hex float8 = [52]  ushort value=82
TestCastFloat8 (EURUSD,H1)      10 input value =11.000000  Hex float8 = [53]  ushort value=83
TestCastFloat8 (EURUSD,H1)      11 input value =12.000000  Hex float8 = [54]  ushort value=84
TestCastFloat8 (EURUSD,H1)      12 input value =13.000000  Hex float8 = [55]  ushort value=85
TestCastFloat8 (EURUSD,H1)      13 input value =14.000000  Hex float8 = [56]  ushort value=86
TestCastFloat8 (EURUSD,H1)      14 input value =15.000000  Hex float8 = [57]  ushort value=87
TestCastFloat8 (EURUSD,H1)      ONNX input array: [56,64,68,72,74,76,78,80,81,82,83,84,85,86,87]
TestCastFloat8 (EURUSD,H1)      ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0]
TestCastFloat8 (EURUSD,H1)      0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastFloat8 (EURUSD,H1)      1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      8 output float 9.000000 = [00,00,10,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      10 output float 11.000000 = [00,00,30,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      12 output float 13.000000 = [00,00,50,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      13 output float 14.000000 = [00,00,60,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      14 output float 15.000000 = [00,00,70,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      RunCastFloat8ToFloat(FLOAT_FP8_E4M3FN): sum_error=0.000000
TestCastFloat8 (EURUSD,H1)      
TestCastFloat8 (EURUSD,H1)      TEST: RunCastFloat8ToFloat(FLOAT_FP8_E4M3FNUZ)
TestCastFloat8 (EURUSD,H1)      0 input value =1.000000  Hex float8 = [40]  ushort value=64
TestCastFloat8 (EURUSD,H1)      1 input value =2.000000  Hex float8 = [48]  ushort value=72
TestCastFloat8 (EURUSD,H1)      2 input value =3.000000  Hex float8 = [4c]  ushort value=76
TestCastFloat8 (EURUSD,H1)      3 input value =4.000000  Hex float8 = [50]  ushort value=80
TestCastFloat8 (EURUSD,H1)      4 input value =5.000000  Hex float8 = [52]  ushort value=82
TestCastFloat8 (EURUSD,H1)      5 input value =6.000000  Hex float8 = [54]  ushort value=84
TestCastFloat8 (EURUSD,H1)      6 input value =7.000000  Hex float8 = [56]  ushort value=86
TestCastFloat8 (EURUSD,H1)      7 input value =8.000000  Hex float8 = [58]  ushort value=88
TestCastFloat8 (EURUSD,H1)      8 input value =9.000000  Hex float8 = [59]  ushort value=89
TestCastFloat8 (EURUSD,H1)      9 input value =10.000000  Hex float8 = [5a]  ushort value=90
TestCastFloat8 (EURUSD,H1)      10 input value =11.000000  Hex float8 = [5b]  ushort value=91
TestCastFloat8 (EURUSD,H1)      11 input value =12.000000  Hex float8 = [5c]  ushort value=92
TestCastFloat8 (EURUSD,H1)      12 input value =13.000000  Hex float8 = [5d]  ushort value=93
TestCastFloat8 (EURUSD,H1)      13 input value =14.000000  Hex float8 = [5e]  ushort value=94
TestCastFloat8 (EURUSD,H1)      14 input value =15.000000  Hex float8 = [5f]  ushort value=95
TestCastFloat8 (EURUSD,H1)      ONNX input array: [64,72,76,80,82,84,86,88,89,90,91,92,93,94,95]
TestCastFloat8 (EURUSD,H1)      ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0]
TestCastFloat8 (EURUSD,H1)      0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastFloat8 (EURUSD,H1)      1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      8 output float 9.000000 = [00,00,10,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      10 output float 11.000000 = [00,00,30,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      12 output float 13.000000 = [00,00,50,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      13 output float 14.000000 = [00,00,60,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      14 output float 15.000000 = [00,00,70,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      RunCastFloat8ToFloat(FLOAT_FP8_E4M3FNUZ): sum_error=0.000000
TestCastFloat8 (EURUSD,H1)      
TestCastFloat8 (EURUSD,H1)      TEST: RunCastFloat8ToFloat(FLOAT_FP8_E5M2FN)
TestCastFloat8 (EURUSD,H1)      0 input value =1.000000  Hex float8 = [3c]  ushort value=60
TestCastFloat8 (EURUSD,H1)      1 input value =2.000000  Hex float8 = [40]  ushort value=64
TestCastFloat8 (EURUSD,H1)      2 input value =3.000000  Hex float8 = [42]  ushort value=66
TestCastFloat8 (EURUSD,H1)      3 input value =4.000000  Hex float8 = [44]  ushort value=68
TestCastFloat8 (EURUSD,H1)      4 input value =5.000000  Hex float8 = [45]  ushort value=69
TestCastFloat8 (EURUSD,H1)      5 input value =6.000000  Hex float8 = [46]  ushort value=70
TestCastFloat8 (EURUSD,H1)      6 input value =7.000000  Hex float8 = [47]  ushort value=71
TestCastFloat8 (EURUSD,H1)      7 input value =8.000000  Hex float8 = [48]  ushort value=72
TestCastFloat8 (EURUSD,H1)      8 input value =8.000000  Hex float8 = [48]  ushort value=72
TestCastFloat8 (EURUSD,H1)      9 input value =10.000000  Hex float8 = [49]  ushort value=73
TestCastFloat8 (EURUSD,H1)      10 input value =12.000000  Hex float8 = [4a]  ushort value=74
TestCastFloat8 (EURUSD,H1)      11 input value =12.000000  Hex float8 = [4a]  ushort value=74
TestCastFloat8 (EURUSD,H1)      12 input value =12.000000  Hex float8 = [4a]  ushort value=74
TestCastFloat8 (EURUSD,H1)      13 input value =14.000000  Hex float8 = [4b]  ushort value=75
TestCastFloat8 (EURUSD,H1)      14 input value =16.000000  Hex float8 = [4c]  ushort value=76
TestCastFloat8 (EURUSD,H1)      ONNX input array: [60,64,66,68,69,70,71,72,72,73,74,74,74,75,76]
TestCastFloat8 (EURUSD,H1)      ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,8.0,10.0,12.0,12.0,12.0,14.0,16.0]
TestCastFloat8 (EURUSD,H1)      0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastFloat8 (EURUSD,H1)      1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      8 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      10 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      12 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      13 output float 14.000000 = [00,00,60,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      14 output float 16.000000 = [00,00,80,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      RunCastFloat8ToFloat(FLOAT_FP8_E5M2FN): sum_error=0.000000
TestCastFloat8 (EURUSD,H1)      
TestCastFloat8 (EURUSD,H1)      TEST: RunCastFloat8ToFloat(FLOAT_FP8_E5M2FNUZ)
TestCastFloat8 (EURUSD,H1)      0 input value =1.000000  Hex float8 = [40]  ushort value=64
TestCastFloat8 (EURUSD,H1)      1 input value =2.000000  Hex float8 = [44]  ushort value=68
TestCastFloat8 (EURUSD,H1)      2 input value =3.000000  Hex float8 = [46]  ushort value=70
TestCastFloat8 (EURUSD,H1)      3 input value =4.000000  Hex float8 = [48]  ushort value=72
TestCastFloat8 (EURUSD,H1)      4 input value =5.000000  Hex float8 = [49]  ushort value=73
TestCastFloat8 (EURUSD,H1)      5 input value =6.000000  Hex float8 = [4a]  ushort value=74
TestCastFloat8 (EURUSD,H1)      6 input value =7.000000  Hex float8 = [4b]  ushort value=75
TestCastFloat8 (EURUSD,H1)      7 input value =8.000000  Hex float8 = [4c]  ushort value=76
TestCastFloat8 (EURUSD,H1)      8 input value =8.000000  Hex float8 = [4c]  ushort value=76
TestCastFloat8 (EURUSD,H1)      9 input value =10.000000  Hex float8 = [4d]  ushort value=77
TestCastFloat8 (EURUSD,H1)      10 input value =12.000000  Hex float8 = [4e]  ushort value=78
TestCastFloat8 (EURUSD,H1)      11 input value =12.000000  Hex float8 = [4e]  ushort value=78
TestCastFloat8 (EURUSD,H1)      12 input value =12.000000  Hex float8 = [4e]  ushort value=78
TestCastFloat8 (EURUSD,H1)      13 input value =14.000000  Hex float8 = [4f]  ushort value=79
TestCastFloat8 (EURUSD,H1)      14 input value =16.000000  Hex float8 = [50]  ushort value=80
TestCastFloat8 (EURUSD,H1)      ONNX input array: [64,68,70,72,73,74,75,76,76,77,78,78,78,79,80]
TestCastFloat8 (EURUSD,H1)      ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,8.0,10.0,12.0,12.0,12.0,14.0,16.0]
TestCastFloat8 (EURUSD,H1)      0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastFloat8 (EURUSD,H1)      1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      8 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      10 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      12 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      13 output float 14.000000 = [00,00,60,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      14 output float 16.000000 = [00,00,80,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      RunCastFloat8ToFloat(FLOAT_FP8_E5M2FNUZ): sum_error=0.000000
TestCastFloat8 (EURUSD,H1)
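
In the two E5M2 runs above, the inputs 9, 11, 13, and 15 arrive as 8, 12, 12, and 16, while the E4M3 runs keep every value. A plausible explanation is rounding to nearest (ties to even) onto a grid with only 2 mantissa bits: between 8 and 16 the representable values are 2 apart. The Python sketch below illustrates this under simplifying assumptions: it ignores saturation, subnormals, and special values, and quantize is a hypothetical helper, not the ArrayToFP8 implementation.

```python
import math

def quantize(x, mantissa_bits):
    """Round x to the nearest value that has `mantissa_bits` explicit mantissa bits."""
    if x == 0.0:
        return 0.0
    _, e = math.frexp(abs(x))               # abs(x) = m * 2**e with 0.5 <= m < 1
    step = 2.0 ** (e - 1 - mantissa_bits)   # spacing of representable values near x
    return round(x / step) * step           # Python's round() breaks ties to even

values = [float(v) for v in range(1, 16)]   # the same test data as in the script
e5m2 = [quantize(v, 2) for v in values]     # 2 mantissa bits
e4m3 = [quantize(v, 3) for v in values]     # 3 mantissa bits

print(e5m2)
print(e4m3)
```

With 2 mantissa bits the list matches the "ONNX output array" lines printed for FLOAT_FP8_E5M2FN and FLOAT_FP8_E5M2FNUZ; with 3 mantissa bits all integers from 1 to 15 are preserved.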



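2. Using ONNX for image super-resolution

In this section, we look at an example of using the ESRGAN model to increase image resolution.

ESRGAN, or Enhanced Super-Resolution Generative Adversarial Networks, is a powerful neural network architecture designed for the image super-resolution task. ESRGAN was developed to improve image quality by raising the resolution to a higher level. This is achieved by training a deep neural network on a large dataset of low-resolution images paired with their corresponding high-quality images. ESRGAN uses the architecture of generative adversarial networks (GANs), which consists of two main components: a generator and a discriminator. The generator is responsible for creating high-resolution images, while the discriminator is trained to distinguish generated images from real ones.

At the core of the ESRGAN architecture are residual blocks, which help extract and preserve important image features at different levels of abstraction. This allows the network to effectively restore details and textures in high-quality images.

To achieve high quality and versatility in super-resolution tasks, ESRGAN requires large training datasets. This lets the network learn a wide variety of image styles and features, making it better suited to different kinds of input data. ESRGAN can be used to improve image quality in many fields, including photography, medical diagnostics, film and video production, and graphic design. Its flexibility and efficiency make it one of the leading methods in image super-resolution.

ESRGAN represents a significant advance in image processing and artificial intelligence, opening up new possibilities for creating and enhancing images.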


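2.1. Running the ONNX model with float32

To run the example, download the file https://github.com/amannm/super-resolution-service/blob/main/models/esrgan.onnx and copy it to the folder \MQL5\Scripts\models.

The ESRGAN.onnx model contains about 1200 ONNX operations; the first of them are shown in Figure 12.

Figure 12. Description of the ESRGAN.onnx model in MetaEditor



Figure 13. The ESRGAN.ONNX model in Netron


The code below demonstrates 4x image upscaling with ESRGAN.onnx.

It first loads the esrgan.onnx model, then selects and loads the source image in BMP format. The image is then split into separate RGB channels, which are fed to the model as input. The model upscales the image by a factor of 4, after which the resulting image is converted back and prepared for display.

The Canvas library is used for display, and the ONNX Runtime library for model execution. After the program finishes, the upscaled image is saved to a file whose name is the original file name with "_upscaled" appended. The key functions cover image preprocessing and postprocessing, and running the model to upscale the image.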

//+------------------------------------------------------------------+
//|                                                       ESRGAN.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
//+------------------------------------------------------------------+
//| 4x image upscaling demo using ESRGAN                             |
//| esrgan.onnx model from                                           |
//| https://github.com/amannm/super-resolution-service/              |
//+------------------------------------------------------------------+
//| Xintao Wang et al (2018)                                         |
//| ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks|
//| https://arxiv.org/abs/1809.00219                                 |
//+------------------------------------------------------------------+
#resource "models\\esrgan.onnx" as uchar ExtModel[];
#include <Canvas\Canvas.mqh>
//+------------------------------------------------------------------+
//| clamp                                                            |
//+------------------------------------------------------------------+
float clamp(float value, float minValue, float maxValue)
  {
   return MathMin(MathMax(value, minValue), maxValue);
  }
//+------------------------------------------------------------------+
//| Preprocessing                                                    |
//+------------------------------------------------------------------+
bool Preprocessing(float &data[],uint &image_data[],int &image_width,int &image_height)
  {
//--- checkup
   if(image_height==0 || image_width==0)
      return(false);
//--- prepare destination array with separated RGB channels for ONNX model
   int data_count=3*image_width*image_height;

   if(ArrayResize(data,data_count)!=data_count)
     {
      Print("ArrayResize failed");
      return(false);
     }
//--- converting
   for(int y=0; y<image_height; y++)
      for(int x=0; x<image_width; x++)
        {
         //--- load source RGB
         int   offset=y*image_width+x;
         uint  clr   =image_data[offset];
         uchar r     =GETRGBR(clr);
         uchar g     =GETRGBG(clr);
         uchar b     =GETRGBB(clr);
         //--- store RGB components as separated channels
         int offset_ch1=0*image_width*image_height+offset;
         int offset_ch2=1*image_width*image_height+offset;
         int offset_ch3=2*image_width*image_height+offset;

         data[offset_ch1]=r/255.0f;
         data[offset_ch2]=g/255.0f;
         data[offset_ch3]=b/255.0f;
        }
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| PostProcessing                                                   |
//+------------------------------------------------------------------+
bool PostProcessing(const float &data[], uint &image_data[], const int &image_width, const int &image_height)
  {
//--- checks
   if(image_height == 0 || image_width == 0)
      return(false);

   int data_count=image_width*image_height;
   
   if(ArraySize(data)!=3*data_count)
      return(false);
   if(ArrayResize(image_data,data_count)!=data_count)
      return(false);
//---
   for(int y=0; y<image_height; y++)
      for(int x=0; x<image_width; x++)
        {
         int offset    =y*image_width+x;
         int offset_ch1=0*image_width*image_height+offset;
         int offset_ch2=1*image_width*image_height+offset;
         int offset_ch3=2*image_width*image_height+offset;
         //--- rescale to [0..255]
         float r=clamp(data[offset_ch1]*255,0,255);
         float g=clamp(data[offset_ch2]*255,0,255);
         float b=clamp(data[offset_ch3]*255,0,255);
         //--- set color image_data
         image_data[offset]=XRGB(uchar(r),uchar(g),uchar(b));
        }
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| ShowImage                                                        |
//+------------------------------------------------------------------+
bool ShowImage(CCanvas &canvas,const string name,const int x0,const int y0,const int image_width,const int image_height, const uint &image_data[])
  {
   if(ArraySize(image_data)==0 || name=="")
      return(false);
//--- prepare canvas
   canvas.CreateBitmapLabel(name,x0,y0,image_width,image_height,COLOR_FORMAT_XRGB_NOALPHA);
//--- copy image to canvas
   for(int y=0; y<image_height; y++)
      for(int x=0; x<image_width; x++)
         canvas.PixelSet(x,y,image_data[y*image_width+x]);
//--- ready to draw
   canvas.Update(true);
   return(true);
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
int OnStart(void)
  {
//--- select BMP from <data folder>\MQL5\Files
   string image_path[1];

   if(FileSelectDialog("Select BMP image",NULL,"Bitmap files (*.bmp)|*.bmp",FSD_FILE_MUST_EXIST,image_path,"lenna-original4.bmp")!=1)
     {
      Print("file not selected");
      return(-1);
     }
//--- load BMP into array
   uint image_data[];
   int  image_width;
   int  image_height;

   if(!CCanvas::LoadBitmap(image_path[0],image_data,image_width,image_height))
     {
      PrintFormat("CCanvas::LoadBitmap failed with error %d",GetLastError());
      return(-1);
     }
//--- convert RGB image to separated RGB channels
   float input_data[];
   Preprocessing(input_data,image_data,image_width,image_height);
   PrintFormat("input array size=%d",ArraySize(input_data));
//--- load model
   long model_handle=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT);

   if(model_handle==INVALID_HANDLE)
     {
      PrintFormat("OnnxCreate error %d",GetLastError());
      return(-1);
     }

   PrintFormat("model loaded successfully");
   PrintFormat("original:  width=%d, height=%d  Size=%d",image_width,image_height,ArraySize(image_data));
//--- set input shape
   ulong input_shape[]={1,3,image_height,image_width};

   if(!OnnxSetInputShape(model_handle,0,input_shape))
     {
      PrintFormat("error in OnnxSetInputShape. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }
//--- upscaled image size
   int   new_image_width =4*image_width;
   int   new_image_height=4*image_height;
   ulong output_shape[]= {1,3,new_image_height,new_image_width};

   if(!OnnxSetOutputShape(model_handle,0,output_shape))
     {
      PrintFormat("error in OnnxSetOutputShape. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }
//--- run the model
   float output_data[];
   int new_data_count=3*new_image_width*new_image_height;
   if(ArrayResize(output_data,new_data_count)!=new_data_count)
     {
      OnnxRelease(model_handle);
      return(-1);
     }

   if(!OnnxRun(model_handle,ONNX_DEBUG_LOGS,input_data,output_data))
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }

   Print("model successfully executed, output data size ",ArraySize(output_data));
   OnnxRelease(model_handle);
//--- postprocessing
   uint new_image[];
   PostProcessing(output_data,new_image,new_image_width,new_image_height);
//--- show images
   CCanvas canvas_original,canvas_scaled;
   ShowImage(canvas_original,"original_image",new_image_width,0,image_width,image_height,image_data);
   ShowImage(canvas_scaled,"upscaled_image",0,0,new_image_width,new_image_height,new_image);
//--- save upscaled image
   StringReplace(image_path[0],".bmp","_upscaled.bmp");
   Print(ResourceSave(canvas_scaled.ResourceName(),image_path[0]));
//---
   while(!IsStopped())
      Sleep(100);

   return(0);
  }
//+------------------------------------------------------------------+

Output:

Figure 14. Result of running the ESRGAN.onnx model (160x200->640x800)

In this example, the ESRGAN.onnx model upscales a 160x200 image by a factor of four (to 640x800).
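
The channel layout handled by Preprocessing and PostProcessing in the listing above can be sketched in plain Python. The sketch assumes pixels packed as 0x00RRGGBB integers (our reading of the GETRGBR/GETRGBG/GETRGBB macros) and uses hypothetical function names. Unlike the MQL5 code, which truncates when casting to uchar and sets the alpha byte through XRGB, this version rounds to the nearest integer and leaves the alpha byte at zero.

```python
def to_planar(pixels, width, height):
    """Split packed 0x00RRGGBB pixels into R, G, B planes scaled to [0, 1]."""
    plane = width * height
    data = [0.0] * (3 * plane)
    for offset, clr in enumerate(pixels):
        data[0 * plane + offset] = ((clr >> 16) & 0xFF) / 255.0   # red plane
        data[1 * plane + offset] = ((clr >> 8) & 0xFF) / 255.0    # green plane
        data[2 * plane + offset] = (clr & 0xFF) / 255.0           # blue plane
    return data

def from_planar(data, width, height):
    """Rebuild packed pixels from three planes, clamping each value to [0, 255]."""
    plane = width * height
    pixels = []
    for offset in range(plane):
        channels = []
        for c in range(3):
            value = data[c * plane + offset] * 255.0
            channels.append(int(round(min(max(value, 0.0), 255.0))))
        r, g, b = channels
        pixels.append((r << 16) | (g << 8) | b)
    return pixels

image = [0x00102030, 0x00FFFFFF, 0x00000000, 0x00FF0080]   # a 2x2 test image
planes = to_planar(image, 2, 2)
print(len(planes), from_planar(planes, 2, 2) == image)
```

The planar array has 3*width*height elements, which corresponds to the model input shape {1,3,height,width} used in the script.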


2.2. Example of running the ONNX model with float16

To convert the model to float16, we will use the method described in "Create Float16 and Mixed Precision Models".

# Copyright 2024, MetaQuotes Ltd.
# https://www.mql5.com

import onnx
from onnxconverter_common import float16

from sys import argv

# Define the path for saving the model
data_path = argv[0]
last_index = data_path.rfind("\\") + 1
data_path = data_path[0:last_index]

# convert the model to float16
model_path = data_path+'\\models\\esrgan.onnx'
modelfp16_path = data_path+'\\models\\esrgan_float16.onnx'

model = onnx.load(model_path)
model_fp16 = float16.convert_float_to_float16(model)
onnx.save(model_fp16, modelfp16_path)

After the conversion, the file size is halved (from 64 MB to 32 MB).
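
The halving is what one would expect if most of the file is weight data, since each float32 value takes 4 bytes and each float16 value takes 2. The small Python sketch below illustrates the storage saving and the precision cost with the standard struct module; the sample values are illustrative and are not taken from the model.

```python
import struct

weights = [0.1, -1.5, 3.14159, 1000.0]   # illustrative values only

n = len(weights)
as_fp32 = struct.pack(f"<{n}f", *weights)     # 4 bytes per value
as_fp16 = struct.pack(f"<{n}e", *weights)     # 2 bytes per value (IEEE half precision)
restored = struct.unpack(f"<{n}e", as_fp16)

# relative error introduced by storing each value as float16
max_rel_error = max(abs(r - w) / abs(w) for r, w in zip(restored, weights))

print(len(as_fp32), len(as_fp16))   # 16 8
print(max_rel_error)
```

With 10 explicit mantissa bits, the relative rounding error of a normal float16 value stays below about 0.05%, which is why the conversion is often acceptable for inference.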

The changes to the code are minimal.

//+------------------------------------------------------------------+
//|                                               ESRGAN_float16.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
//+------------------------------------------------------------------+
//| 4x image upscaling demo using ESRGAN                             |
//| esrgan.onnx model from                                           |
//| https://github.com/amannm/super-resolution-service/              |
//+------------------------------------------------------------------+
//| Xintao Wang et al (2018)                                         |
//| ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks|
//| https://arxiv.org/abs/1809.00219                                 |
//+------------------------------------------------------------------+
#resource "models\\esrgan_float16.onnx" as uchar ExtModel[];
#include <Canvas\Canvas.mqh>
//+------------------------------------------------------------------+
//| clamp                                                            |
//+------------------------------------------------------------------+
float clamp(float value, float minValue, float maxValue)
  {
   return MathMin(MathMax(value, minValue), maxValue);
  }
//+------------------------------------------------------------------+
//| Preprocessing                                                    |
//+------------------------------------------------------------------+
bool Preprocessing(float &data[],uint &image_data[],int &image_width,int &image_height)
  {
//--- checkup
   if(image_height==0 || image_width==0)
      return(false);
//--- prepare destination array with separated RGB channels for ONNX model
   int data_count=3*image_width*image_height;

   if(ArrayResize(data,data_count)!=data_count)
     {
      Print("ArrayResize failed");
      return(false);
     }
//--- converting
   for(int y=0; y<image_height; y++)
      for(int x=0; x<image_width; x++)
        {
         //--- load source RGB
         int   offset=y*image_width+x;
         uint  clr   =image_data[offset];
         uchar r     =GETRGBR(clr);
         uchar g     =GETRGBG(clr);
         uchar b     =GETRGBB(clr);
         //--- store RGB components as separated channels
         int offset_ch1=0*image_width*image_height+offset;
         int offset_ch2=1*image_width*image_height+offset;
         int offset_ch3=2*image_width*image_height+offset;

         data[offset_ch1]=r/255.0f;
         data[offset_ch2]=g/255.0f;
         data[offset_ch3]=b/255.0f;
        }
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| PostProcessing                                                   |
//+------------------------------------------------------------------+
bool PostProcessing(const float &data[], uint &image_data[], const int &image_width, const int &image_height)
  {
//--- checks
   if(image_height == 0 || image_width == 0)
      return(false);

   int data_count=image_width*image_height;
   
   if(ArraySize(data)!=3*data_count)
      return(false);
   if(ArrayResize(image_data,data_count)!=data_count)
      return(false);
//---
   for(int y=0; y<image_height; y++)
      for(int x=0; x<image_width; x++)
        {
         int offset    =y*image_width+x;
         int offset_ch1=0*image_width*image_height+offset;
         int offset_ch2=1*image_width*image_height+offset;
         int offset_ch3=2*image_width*image_height+offset;
         //--- rescale to [0..255]
         float r=clamp(data[offset_ch1]*255,0,255);
         float g=clamp(data[offset_ch2]*255,0,255);
         float b=clamp(data[offset_ch3]*255,0,255);
         //--- set color image_data
         image_data[offset]=XRGB(uchar(r),uchar(g),uchar(b));
        }
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| ShowImage                                                        |
//+------------------------------------------------------------------+
bool ShowImage(CCanvas &canvas,const string name,const int x0,const int y0,const int image_width,const int image_height, const uint &image_data[])
  {
   if(ArraySize(image_data)==0 || name=="")
      return(false);
//--- prepare canvas
   canvas.CreateBitmapLabel(name,x0,y0,image_width,image_height,COLOR_FORMAT_XRGB_NOALPHA);
//--- copy image to canvas
   for(int y=0; y<image_height; y++)
      for(int x=0; x<image_width; x++)
         canvas.PixelSet(x,y,image_data[y*image_width+x]);
//--- ready to draw
   canvas.Update(true);
   return(true);
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
int OnStart(void)
  {
//--- select BMP from <data folder>\MQL5\Files
   string image_path[1];

   if(FileSelectDialog("Select BMP image",NULL,"Bitmap files (*.bmp)|*.bmp",FSD_FILE_MUST_EXIST,image_path,"lenna.bmp")!=1)
     {
      Print("file not selected");
      return(-1);
     }
//--- load BMP into array
   uint image_data[];
   int  image_width;
   int  image_height;

   if(!CCanvas::LoadBitmap(image_path[0],image_data,image_width,image_height))
     {
      PrintFormat("CCanvas::LoadBitmap failed with error %d",GetLastError());
      return(-1);
     }
//--- convert RGB image to separated RGB channels
   float input_data[];
   Preprocessing(input_data,image_data,image_width,image_height);
   PrintFormat("input array size=%d",ArraySize(input_data));
   
   ushort input_data_float16[];
   if(!ArrayToFP16(input_data_float16,input_data,FLOAT_FP16))
     {
      Print("error in ArrayToFP16. error code=",GetLastError());
      return(-1);
     }   
//--- load model
   long model_handle=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT);
   if(model_handle==INVALID_HANDLE)
     {
      PrintFormat("OnnxCreateFromBuffer error %d",GetLastError());
      return(-1);
     }

   PrintFormat("model loaded successfully");
   PrintFormat("original:  width=%d, height=%d  Size=%d",image_width,image_height,ArraySize(image_data));
//--- set input shape
   ulong input_shape[]={1,3,image_height,image_width};

   if(!OnnxSetInputShape(model_handle,0,input_shape))
     {
      PrintFormat("error in OnnxSetInputShape. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }
//--- upscaled image size
   int   new_image_width =4*image_width;
   int   new_image_height=4*image_height;
   ulong output_shape[]= {1,3,new_image_height,new_image_width};

   if(!OnnxSetOutputShape(model_handle,0,output_shape))
     {
      PrintFormat("error in OnnxSetOutputShape. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }
//--- run the model
   float output_data[];
   ushort output_data_float16[];
   int new_data_count=3*new_image_width*new_image_height;
   if(ArrayResize(output_data_float16,new_data_count)!=new_data_count)
     {
      OnnxRelease(model_handle);
      return(-1);
     }

   if(!OnnxRun(model_handle,ONNX_NO_CONVERSION,input_data_float16,output_data_float16))
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }

   Print("model successfully executed, output data size ",ArraySize(output_data_float16));
   OnnxRelease(model_handle);
   
   if(!ArrayFromFP16(output_data,output_data_float16,FLOAT_FP16))
     {
      Print("error in ArrayFromFP16. error code=",GetLastError());
      return(-1);
     }   
//--- postprocessing
   uint new_image[];
   PostProcessing(output_data,new_image,new_image_width,new_image_height);
//--- show images
   CCanvas canvas_original,canvas_scaled;
   ShowImage(canvas_original,"original_image",new_image_width,0,image_width,image_height,image_data);
   ShowImage(canvas_scaled,"upscaled_image",0,0,new_image_width,new_image_height,new_image);
//--- save upscaled image
   StringReplace(image_path[0],".bmp","_upscaled.bmp");
   Print(ResourceSave(canvas_scaled.ResourceName(),image_path[0]));
//---
   while(!IsStopped())
      Sleep(100);

   return(0);
  }
//+------------------------------------------------------------------+

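The changes required to run the model converted to float16 (highlighted in color in the original listing) concern the float16 buffers: the input is converted with ArrayToFP16, OnnxRun works on ushort arrays, and the output is converted back with ArrayFromFP16.

Output:

Fig. 15. Result of executing the ESRGAN_float16.onnx model (160x200 -> 640x800)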


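Thus, using float16 numbers instead of float32 halves the size of the ONNX model file (from 64 MB to 32 MB).

When the model is executed with float16 numbers, the image quality is preserved, and it is hard to find any difference visually: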


Fig. 16. Comparison of the ESRGAN model results for the float and float16 types
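One way to see why the difference is hard to spot is to measure how much a normalized pixel value changes when it is stored as float16. The sketch below is an illustration under a narrow assumption: it covers only the rounding of the values k/255 themselves, not the error accumulated inside the network. The helper name to_fp16_and_back is hypothetical.

```python
import struct

def to_fp16_and_back(x):
    """Round a Python float to the nearest float16 value and return it as a float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# All 256 normalized channel values used by the preprocessing step
levels = [k / 255.0 for k in range(256)]

# Largest rounding error over all levels, compared with half of one 8-bit step
max_error = max(abs(to_fp16_and_back(v) - v) for v in levels)
half_step = 0.5 / 255.0

print(max_error, half_step)
```

The largest rounding error stays well below half of one 8-bit brightness step, so each stored level remains closest to its original value.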


The code changes are minimal: only the conversion of the input and output data needs attention.
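The conversion step can be imitated outside the terminal. The following sketch shows float16 values held as 16-bit unsigned integers, in the same way the MQL5 script keeps them in ushort arrays. The functions array_to_fp16 and array_from_fp16 are hypothetical Python stand-ins written for this illustration; they are not the MQL5 functions and make no claim about how those are implemented.

```python
import struct

def array_to_fp16(values):
    """Return the 16-bit patterns (ints in 0..65535) of values rounded to float16."""
    return [struct.unpack("<H", struct.pack("<e", v))[0] for v in values]

def array_from_fp16(bits):
    """Decode 16-bit float16 patterns back into Python floats."""
    return [struct.unpack("<e", struct.pack("<H", b))[0] for b in bits]

# Values that float16 represents exactly, so the round trip is lossless
pixels = [0.0, 0.5, 1.0]
bits = array_to_fp16(pixels)
restored = array_from_fp16(bits)

print([hex(b) for b in bits], restored)
```

Values that are not exactly representable in float16 come back rounded to the nearest float16 value, not unchanged.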

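In this case, converting the model to float16 did not noticeably change the quality of its results. When analyzing financial data, however, one should strive for the highest possible accuracy.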
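The caution about financial data can be illustrated with a single number. Between 1 and 2, adjacent float16 values are 2^-10 (about 0.00098) apart, so a quote with five decimal places cannot be stored exactly. The price below is a made-up example, and to_fp16_and_back is a hypothetical helper.

```python
import struct

def to_fp16_and_back(x):
    """Round a Python float to the nearest float16 value and return it as a float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

price = 1.23456                      # hypothetical quote
stored = to_fp16_and_back(price)     # nearest float16 value: 1.234375
error = abs(stored - price)

print(stored, error)
```

The rounding error here is larger than 0.0001, which can matter when a model works directly on raw prices.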


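Conclusion

Using the new floating-point data types makes it possible to reduce the size of ONNX models without a significant loss of quality.

The conversion functions ArrayToFP16/ArrayFromFP16 and ArrayToFP8/ArrayFromFP8 greatly simplify the preprocessing and postprocessing of data.

Working with the converted ONNX models requires only minimal changes to the code.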