Работа с ONNX-моделями в форматах float16 и float8

MetaQuotes | 27 февраля, 2024

Содержание

С развитием технологий машинного обучения и искусственного интеллекта возникает необходимость в оптимизации процессов работы с моделями. Эффективность работы моделей напрямую зависит от форматов данных, используемых для их представления. В последние годы появилось несколько новых типов данных, предназначенных специально для работы с моделями глубокого обучения.

В данной статье мы сосредоточимся на двух таких новых форматах данных — float16 и float8, которые начинают активно использоваться в современных ONNX-моделях. Эти форматы представляют собой альтернативные варианты более точных, но требовательных к ресурсам форматам данных с плавающей точкой. Они обеспечивают оптимальное сочетание производительности и точности, что делает их особенно привлекательными для различных задач машинного обучения. Мы изучим основные характеристики и преимущества форматов float16 и float8, а также представим функции для их преобразования в стандартные float и double.

Это поможет разработчикам и исследователям лучше понять, как эффективно использовать эти форматы в своих проектах и моделях. В качестве примера мы рассмотрим работу ONNX-модели ESRGAN, которая применяется для улучшения качества изображений.


1. Новые типы данных для работы с ONNX-моделями

Для ускорения расчетов некоторые модели используют типы данных с меньшей точностью, такие как Float16 и даже Float8.

Для работы с ONNX-моделями в язык MQL5 добавлена поддержка новых типов данных, позволяющих работать с 8 и 16-битными представлениями чисел с плавающей точкой.

Скрипт выводит полный список элементов перечисления ENUM_ONNX_DATA_TYPE.

//+------------------------------------------------------------------+
//|                                              ONNX_Data_Types.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
void OnStart()
  {
//---
   for(int i=0; i<21; i++)
      PrintFormat("%2d %s",i,EnumToString(ENUM_ONNX_DATA_TYPE(i)));
  }

Результат:

 0: ONNX_DATA_TYPE_UNDEFINED
 1: ONNX_DATA_TYPE_FLOAT
 2: ONNX_DATA_TYPE_UINT8
 3: ONNX_DATA_TYPE_INT8
 4: ONNX_DATA_TYPE_UINT16
 5: ONNX_DATA_TYPE_INT16
 6: ONNX_DATA_TYPE_INT32
 7: ONNX_DATA_TYPE_INT64
 8: ONNX_DATA_TYPE_STRING
 9: ONNX_DATA_TYPE_BOOL
10: ONNX_DATA_TYPE_FLOAT16
11: ONNX_DATA_TYPE_DOUBLE
12: ONNX_DATA_TYPE_UINT32
13: ONNX_DATA_TYPE_UINT64
14: ONNX_DATA_TYPE_COMPLEX64
15: ONNX_DATA_TYPE_COMPLEX128
16: ONNX_DATA_TYPE_BFLOAT16
17: ONNX_DATA_TYPE_FLOAT8E4M3FN
18: ONNX_DATA_TYPE_FLOAT8E4M3FNUZ
19: ONNX_DATA_TYPE_FLOAT8E5M2
20: ONNX_DATA_TYPE_FLOAT8E5M2FNUZ

Таким образом, теперь можно исполнять ONNX-модели, работающие с такими данными.

Кроме того, в MQL5 появились дополнительные функции для конвертации данных:

bool ArrayToFP16(ushort &dst_array[],const float &src_array[],ENUM_FLOAT16_FORMAT fmt);
bool ArrayToFP16(ushort &dst_array[],const double &src_array[],ENUM_FLOAT16_FORMAT fmt);
bool ArrayToFP8(uchar &dst_array[],const float &src_array[],ENUM_FLOAT8_FORMAT fmt);
bool ArrayToFP8(uchar &dst_array[],const double &src_array[],ENUM_FLOAT8_FORMAT fmt);

bool ArrayFromFP16(float &dst_array[],const ushort &src_array[],ENUM_FLOAT16_FORMAT fmt);
bool ArrayFromFP16(double &dst_array[],const ushort &src_array[],ENUM_FLOAT16_FORMAT fmt);
bool ArrayFromFP8(float &dst_array[],const uchar &src_array[],ENUM_FLOAT8_FORMAT fmt);
bool ArrayFromFP8(double &dst_array[],const uchar &src_array[],ENUM_FLOAT8_FORMAT fmt);

Поскольку форматы вещественных чисел для 16 и 8 бит могут отличаться, в параметре fmt в функциях конверсии необходимо указывать, какой именно формат числа требуется обработать.

Для 16-битных версий используется новое перечисление ENUM_FLOAT16_FORMAT, которое на данный момент имеет следующие значения:

Для 8-битных версий используется новое перечисление ENUM_FLOAT8_FORMAT, которое на данный момент имеет следующие значения:


1.1. Формат FP16

Форматы FLOAT16 и BFLOAT16 представляют собой типы данных, используемые для представления чисел с плавающей точкой.

FLOAT16, также известный как половинная точность или формат "half-precision float", использует 16 бит для представления числа с плавающей точкой. Этот формат обеспечивает баланс между точностью и эффективностью вычислений. FLOAT16 широко применяется в глубоком обучении и нейронных сетях, где требуется высокая производительность при обработке больших объемов данных. Этот формат позволяет ускорить вычисления за счет сокращения размера чисел, что особенно важно при обучении глубоких нейронных сетей на графических процессорах (GPU).

BFLOAT16 (или Brain Floating Point 16) также использует 16 бит, но он различается от FLOAT16 в способе представления чисел. В этом формате 8 бит выделены для представления экспоненты, а оставшиеся 7 бит используются для представления мантиссы. Этот формат был разработан для использования в глубоком обучении и искусственном интеллекте, особенно в процессорах Google Tensor Processing Unit (TPU). BFLOAT16 обладает хорошей производительностью при обучении нейронных сетей и может быть эффективно использован для ускорения вычислений.

Оба этих формата имеют свои преимущества и ограничения. FLOAT16 обеспечивает более высокую точность, но требует больше ресурсов для хранения и вычислений. BFLOAT16, с другой стороны, обеспечивает более высокую производительность и эффективность при обработке данных, но может быть менее точным.


Рис. Форматы битового представления чисел с плавающей точкой FLOAT16 и BFLOAT16

Рис.1. Форматы битового представления чисел с плавающей точкой FLOAT16 и BFLOAT16


Рис. Детали формата FLOAT16

Табл.1. Числа с плавающей точкой в формате FLOAT16


1.1.1. Тесты исполнения ONNX-оператора Cast для FLOAT16

В качестве иллюстрации рассмотрим задачу преобразования данных типа FLOAT16 в типы float и double.

ONNX-модели с операцией Cast:

Параметры ONNX-модели test_cast_FLOAT16_to_DOUBLE.onnx

Рис.2. Входные и выходные параметры модели test_cast_FLOAT16_to_DOUBLE.onnx


 Рис. Входные и выходные параметры модели test_cast_FLOAT16_to_FLOAT.onnx

Рис.3. Входные и выходные параметры модели test_cast_FLOAT16_to_FLOAT.onnx


Как видно из описания свойств ONNX-моделей, на входе требуется данные типа  ONNX_DATA_TYPE_FLOAT16, выходные данные модель возвратит в формате ONNX_DATA_TYPE_FLOAT.

Для конвертации значений будем использовать функции ArrayToFP16() и ArrayFromFP16() с параметром FLOAT_FP16.

Пример:

//+------------------------------------------------------------------+
//|                                              TestCastFloat16.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"

#resource "models\\test_cast_FLOAT16_to_DOUBLE.onnx" as const uchar ExtModel1[];
#resource "models\\test_cast_FLOAT16_to_FLOAT.onnx" as const uchar ExtModel2[];

//+------------------------------------------------------------------+
//| union for data conversion                                        |
//+------------------------------------------------------------------+
template<typename T>
union U
  {
   uchar uc[sizeof(T)];
   T value;
  };
//+------------------------------------------------------------------+
//| ArrayToString                                                    |
//+------------------------------------------------------------------+
template<typename T>
string ArrayToString(const T &data[],uint length=16)
  {
   string res;

   for(uint n=0; n<MathMin(length,data.Size()); n++)
      res+="," + StringFormat("%.2x",data[n]);

   StringSetCharacter(res,0,'[');
   return res+"]";
  }

//+------------------------------------------------------------------+
//| PatchONNXModel                                                   |
//+------------------------------------------------------------------+
void PatchONNXModel(const uchar &original_model[],uchar &patched_model[])
  {
   ArrayCopy(patched_model,original_model,0,0,WHOLE_ARRAY);
//--- special ONNX model patch(IR=9,Opset=20)
   patched_model[1]=0x09;
   patched_model[ArraySize(patched_model)-1]=0x14;
  }
//+------------------------------------------------------------------+
//| CreateModel                                                      |
//+------------------------------------------------------------------+
bool CreateModel(long &model_handle,const uchar &model[])
  {
   model_handle=INVALID_HANDLE;
   ulong flags=ONNX_DEFAULT;
//ulong flags=ONNX_DEBUG_LOGS;
//---
   model_handle=OnnxCreateFromBuffer(model,flags);
   if(model_handle==INVALID_HANDLE)
      return(false);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| PrepareShapes                                                    |
//+------------------------------------------------------------------+
bool PrepareShapes(long model_handle)
  {
   ulong input_shape1[]= {3,4};
   if(!OnnxSetInputShape(model_handle,0,input_shape1))
     {
      PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError());
      //--
      OnnxRelease(model_handle);
      return(false);
     }
//---
   ulong output_shape[]= {3,4};
   if(!OnnxSetOutputShape(model_handle,0,output_shape))
     {
      PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError());
      //--
      OnnxRelease(model_handle);
      return(false);
     }
//---
   return(true);
  }

//+------------------------------------------------------------------+
//| RunCastFloat16ToDouble                                           |
//+------------------------------------------------------------------+
bool RunCastFloat16ToDouble(long model_handle)
  {
   PrintFormat("test=%s",__FUNCTION__);

   double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12};
   ushort data_uint16[12];
   if(!ArrayToFP16(data_uint16,test_data,FLOAT_FP16))
     {
      Print("error in ArrayToFP16. error code=",GetLastError());
      return(false);
     }
   Print("test array:");
   ArrayPrint(test_data);
   Print("ArrayToFP16:");
   ArrayPrint(data_uint16);

   U<ushort> input_float16_values[3*4];
   U<double> output_double_values[3*4];

   float test_data_float[];
   if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_FP16))
     {
      Print("error in ArrayFromFP16. error code=",GetLastError());
      return(false);
     }

   for(int i=0; i<12; i++)
     {
      input_float16_values[i].value=data_uint16[i];
      PrintFormat("%d input value =%f  Hex float16 = %s  ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value);
     }

   Print("ONNX input array:");
   ArrayPrint(input_float16_values);

   bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_double_values);
   if(!res)
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      return(false);
     }

   Print("ONNX output array:");
   ArrayPrint(output_double_values);
//---
   double sum_error=0.0;
   for(int i=0; i<12; i++)
     {
      double delta=test_data[i]-output_double_values[i].value;
      sum_error+=MathAbs(delta);
      PrintFormat("%d output double %f = %s  difference=%f",i,output_double_values[i].value,ArrayToString(output_double_values[i].uc),delta);
     }
//---
   PrintFormat("test=%s   sum_error=%f",__FUNCTION__,sum_error);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| RunCastFloat16ToFloat                                            |
//+------------------------------------------------------------------+
bool RunCastFloat16ToFloat(long model_handle)
  {
   PrintFormat("test=%s",__FUNCTION__);

   double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12};
   ushort data_uint16[12];
   if(!ArrayToFP16(data_uint16,test_data,FLOAT_FP16))
     {
      Print("error in ArrayToFP16. error code=",GetLastError());
      return(false);
     }
   Print("test array:");
   ArrayPrint(test_data);
   Print("ArrayToFP16:");
   ArrayPrint(data_uint16);

   U<ushort> input_float16_values[3*4];
   U<float>  output_float_values[3*4];

   float test_data_float[];
   if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_FP16))
     {
      Print("error in ArrayFromFP16. error code=",GetLastError());
      return(false);
     }

   for(int i=0; i<12; i++)
     {
      input_float16_values[i].value=data_uint16[i];
      PrintFormat("%d input value =%f  Hex float16 = %s  ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value);
     }

   Print("ONNX input array:");
   ArrayPrint(input_float16_values);

   bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_float_values);
   if(!res)
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      return(false);
     }

   Print("ONNX output array:");
   ArrayPrint(output_float_values);
//---
   double sum_error=0.0;
   for(int i=0; i<12; i++)
     {
      double delta=test_data[i]-(double)output_float_values[i].value;
      sum_error+=MathAbs(delta);
      PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToString(output_float_values[i].uc),delta);
     }
//---
   PrintFormat("test=%s   sum_error=%f",__FUNCTION__,sum_error);
//---
   return(true);
  }

//+------------------------------------------------------------------+
//| TestCastFloat16ToFloat                                           |
//+------------------------------------------------------------------+
bool TestCastFloat16ToFloat(const uchar &res_model[])
  {
   uchar model[];
   PatchONNXModel(res_model,model);
//--- get model handle
   long model_handle=INVALID_HANDLE;
//--- get model handle
   if(!CreateModel(model_handle,model))
      return(false);
//--- prepare input and output shapes
   if(!PrepareShapes(model_handle))
      return(false);
//--- run ONNX model
   if(!RunCastFloat16ToFloat(model_handle))
      return(false);
//--- release model handle
   OnnxRelease(model_handle);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| TestCastFloat16ToDouble                                          |
//+------------------------------------------------------------------+
bool TestCastFloat16ToDouble(const uchar &res_model[])
  {
   uchar model[];
   PatchONNXModel(res_model,model);
//---
   long model_handle=INVALID_HANDLE;
//--- get model handle
   if(!CreateModel(model_handle,model))
      return(false);
//--- prepare input and output shapes
   if(!PrepareShapes(model_handle))
      return(false);
//--- run ONNX model
   if(!RunCastFloat16ToDouble(model_handle))
      return(false);
//--- release model handle
   OnnxRelease(model_handle);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
int OnStart(void)
  {
   if(!TestCastFloat16ToDouble(ExtModel1))
      return 1;

   if(!TestCastFloat16ToFloat(ExtModel2))
      return 1;
//---
   return 0;
  }
//+------------------------------------------------------------------+

Результат:

TestCastFloat16 (EURUSD,H1)     test=RunCastFloat16ToDouble
TestCastFloat16 (EURUSD,H1)     test array:
TestCastFloat16 (EURUSD,H1)      1.00000  2.00000  3.00000  4.00000  5.00000  6.00000  7.00000  8.00000  9.00000 10.00000 11.00000 12.00000
TestCastFloat16 (EURUSD,H1)     ArrayToFP16:
TestCastFloat16 (EURUSD,H1)     15360 16384 16896 17408 17664 17920 18176 18432 18560 18688 18816 18944
TestCastFloat16 (EURUSD,H1)     0 input value =1.000000  Hex float16 = [00,3c]  ushort value=15360
TestCastFloat16 (EURUSD,H1)     1 input value =2.000000  Hex float16 = [00,40]  ushort value=16384
TestCastFloat16 (EURUSD,H1)     2 input value =3.000000  Hex float16 = [00,42]  ushort value=16896
TestCastFloat16 (EURUSD,H1)     3 input value =4.000000  Hex float16 = [00,44]  ushort value=17408
TestCastFloat16 (EURUSD,H1)     4 input value =5.000000  Hex float16 = [00,45]  ushort value=17664
TestCastFloat16 (EURUSD,H1)     5 input value =6.000000  Hex float16 = [00,46]  ushort value=17920
TestCastFloat16 (EURUSD,H1)     6 input value =7.000000  Hex float16 = [00,47]  ushort value=18176
TestCastFloat16 (EURUSD,H1)     7 input value =8.000000  Hex float16 = [00,48]  ushort value=18432
TestCastFloat16 (EURUSD,H1)     8 input value =9.000000  Hex float16 = [80,48]  ushort value=18560
TestCastFloat16 (EURUSD,H1)     9 input value =10.000000  Hex float16 = [00,49]  ushort value=18688
TestCastFloat16 (EURUSD,H1)     10 input value =11.000000  Hex float16 = [80,49]  ushort value=18816
TestCastFloat16 (EURUSD,H1)     11 input value =12.000000  Hex float16 = [00,4a]  ushort value=18944
TestCastFloat16 (EURUSD,H1)     ONNX input array:
TestCastFloat16 (EURUSD,H1)          [uc] [value]
TestCastFloat16 (EURUSD,H1)     [ 0]  ...   15360
TestCastFloat16 (EURUSD,H1)     [ 1]  ...   16384
TestCastFloat16 (EURUSD,H1)     [ 2]  ...   16896
TestCastFloat16 (EURUSD,H1)     [ 3]  ...   17408
TestCastFloat16 (EURUSD,H1)     [ 4]  ...   17664
TestCastFloat16 (EURUSD,H1)     [ 5]  ...   17920
TestCastFloat16 (EURUSD,H1)     [ 6]  ...   18176
TestCastFloat16 (EURUSD,H1)     [ 7]  ...   18432
TestCastFloat16 (EURUSD,H1)     [ 8]  ...   18560
TestCastFloat16 (EURUSD,H1)     [ 9]  ...   18688
TestCastFloat16 (EURUSD,H1)     [10]  ...   18816
TestCastFloat16 (EURUSD,H1)     [11]  ...   18944
TestCastFloat16 (EURUSD,H1)     ONNX output array:
TestCastFloat16 (EURUSD,H1)          [uc]  [value]
TestCastFloat16 (EURUSD,H1)     [ 0]  ...  1.00000
TestCastFloat16 (EURUSD,H1)     [ 1]  ...  2.00000
TestCastFloat16 (EURUSD,H1)     [ 2]  ...  3.00000
TestCastFloat16 (EURUSD,H1)     [ 3]  ...  4.00000
TestCastFloat16 (EURUSD,H1)     [ 4]  ...  5.00000
TestCastFloat16 (EURUSD,H1)     [ 5]  ...  6.00000
TestCastFloat16 (EURUSD,H1)     [ 6]  ...  7.00000
TestCastFloat16 (EURUSD,H1)     [ 7]  ...  8.00000
TestCastFloat16 (EURUSD,H1)     [ 8]  ...  9.00000
TestCastFloat16 (EURUSD,H1)     [ 9]  ... 10.00000
TestCastFloat16 (EURUSD,H1)     [10]  ... 11.00000
TestCastFloat16 (EURUSD,H1)     [11]  ... 12.00000
TestCastFloat16 (EURUSD,H1)     0 output double 1.000000 = [00,00,00,00,00,00,f0,3f]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     1 output double 2.000000 = [00,00,00,00,00,00,00,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     2 output double 3.000000 = [00,00,00,00,00,00,08,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     3 output double 4.000000 = [00,00,00,00,00,00,10,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     4 output double 5.000000 = [00,00,00,00,00,00,14,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     5 output double 6.000000 = [00,00,00,00,00,00,18,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     6 output double 7.000000 = [00,00,00,00,00,00,1c,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     7 output double 8.000000 = [00,00,00,00,00,00,20,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     8 output double 9.000000 = [00,00,00,00,00,00,22,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     9 output double 10.000000 = [00,00,00,00,00,00,24,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     10 output double 11.000000 = [00,00,00,00,00,00,26,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     11 output double 12.000000 = [00,00,00,00,00,00,28,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     test=RunCastFloat16ToDouble   sum_error=0.000000
TestCastFloat16 (EURUSD,H1)     test=RunCastFloat16ToFloat
TestCastFloat16 (EURUSD,H1)     test array:
TestCastFloat16 (EURUSD,H1)      1.00000  2.00000  3.00000  4.00000  5.00000  6.00000  7.00000  8.00000  9.00000 10.00000 11.00000 12.00000
TestCastFloat16 (EURUSD,H1)     ArrayToFP16:
TestCastFloat16 (EURUSD,H1)     15360 16384 16896 17408 17664 17920 18176 18432 18560 18688 18816 18944
TestCastFloat16 (EURUSD,H1)     0 input value =1.000000  Hex float16 = [00,3c]  ushort value=15360
TestCastFloat16 (EURUSD,H1)     1 input value =2.000000  Hex float16 = [00,40]  ushort value=16384
TestCastFloat16 (EURUSD,H1)     2 input value =3.000000  Hex float16 = [00,42]  ushort value=16896
TestCastFloat16 (EURUSD,H1)     3 input value =4.000000  Hex float16 = [00,44]  ushort value=17408
TestCastFloat16 (EURUSD,H1)     4 input value =5.000000  Hex float16 = [00,45]  ushort value=17664
TestCastFloat16 (EURUSD,H1)     5 input value =6.000000  Hex float16 = [00,46]  ushort value=17920
TestCastFloat16 (EURUSD,H1)     6 input value =7.000000  Hex float16 = [00,47]  ushort value=18176
TestCastFloat16 (EURUSD,H1)     7 input value =8.000000  Hex float16 = [00,48]  ushort value=18432
TestCastFloat16 (EURUSD,H1)     8 input value =9.000000  Hex float16 = [80,48]  ushort value=18560
TestCastFloat16 (EURUSD,H1)     9 input value =10.000000  Hex float16 = [00,49]  ushort value=18688
TestCastFloat16 (EURUSD,H1)     10 input value =11.000000  Hex float16 = [80,49]  ushort value=18816
TestCastFloat16 (EURUSD,H1)     11 input value =12.000000  Hex float16 = [00,4a]  ushort value=18944
TestCastFloat16 (EURUSD,H1)     ONNX input array:
TestCastFloat16 (EURUSD,H1)          [uc] [value]
TestCastFloat16 (EURUSD,H1)     [ 0]  ...   15360
TestCastFloat16 (EURUSD,H1)     [ 1]  ...   16384
TestCastFloat16 (EURUSD,H1)     [ 2]  ...   16896
TestCastFloat16 (EURUSD,H1)     [ 3]  ...   17408
TestCastFloat16 (EURUSD,H1)     [ 4]  ...   17664
TestCastFloat16 (EURUSD,H1)     [ 5]  ...   17920
TestCastFloat16 (EURUSD,H1)     [ 6]  ...   18176
TestCastFloat16 (EURUSD,H1)     [ 7]  ...   18432
TestCastFloat16 (EURUSD,H1)     [ 8]  ...   18560
TestCastFloat16 (EURUSD,H1)     [ 9]  ...   18688
TestCastFloat16 (EURUSD,H1)     [10]  ...   18816
TestCastFloat16 (EURUSD,H1)     [11]  ...   18944
TestCastFloat16 (EURUSD,H1)     ONNX output array:
TestCastFloat16 (EURUSD,H1)          [uc]  [value]
TestCastFloat16 (EURUSD,H1)     [ 0]  ...  1.00000
TestCastFloat16 (EURUSD,H1)     [ 1]  ...  2.00000
TestCastFloat16 (EURUSD,H1)     [ 2]  ...  3.00000
TestCastFloat16 (EURUSD,H1)     [ 3]  ...  4.00000
TestCastFloat16 (EURUSD,H1)     [ 4]  ...  5.00000
TestCastFloat16 (EURUSD,H1)     [ 5]  ...  6.00000
TestCastFloat16 (EURUSD,H1)     [ 6]  ...  7.00000
TestCastFloat16 (EURUSD,H1)     [ 7]  ...  8.00000
TestCastFloat16 (EURUSD,H1)     [ 8]  ...  9.00000
TestCastFloat16 (EURUSD,H1)     [ 9]  ... 10.00000
TestCastFloat16 (EURUSD,H1)     [10]  ... 11.00000
TestCastFloat16 (EURUSD,H1)     [11]  ... 12.00000
TestCastFloat16 (EURUSD,H1)     0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastFloat16 (EURUSD,H1)     1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     8 output float 9.000000 = [00,00,10,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     10 output float 11.000000 = [00,00,30,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     test=RunCastFloat16ToFloat   sum_error=0.000000


1.1.2. Тесты исполнения ONNX-оператора Cast для BFLOAT16

В этом примере рассматривается преобразование из типа BFLOAT16 в float.

ONNX-модель с операцией Cast:


Рис. Входные и выходные параметры модели test_cast_BFLOAT16_to_FLOAT.onnx

Рис.4. Входные и выходные параметры модели test_cast_BFLOAT16_to_FLOAT.onnx

На входе требуется данные типа  ONNX_DATA_TYPE_BFLOAT16, выходные данные модель возвратит в формате ONNX_DATA_TYPE_FLOAT.

Для конвертации значений будем использовать функции ArrayToFP16() и ArrayFromFP16() с параметром BFLOAT_FP16.

//+------------------------------------------------------------------+
//|                                             TestCastBFloat16.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"

#resource "models\\test_cast_BFLOAT16_to_FLOAT.onnx" as const uchar ExtModel1[];

//+------------------------------------------------------------------+
//| union for data conversion                                        |
//+------------------------------------------------------------------+
template<typename T>
union U
  {
   uchar uc[sizeof(T)];
   T value;
  };
//+------------------------------------------------------------------+
//| ArrayToString                                                    |
//+------------------------------------------------------------------+
template<typename T>
string ArrayToString(const T &data[],uint length=16)
  {
   string res;

   for(uint n=0; n<MathMin(length,data.Size()); n++)
      res+="," + StringFormat("%.2x",data[n]);

   StringSetCharacter(res,0,'[');
   return res+"]";
  }

//+------------------------------------------------------------------+
//| PatchONNXModel                                                   |
//+------------------------------------------------------------------+
void PatchONNXModel(const uchar &original_model[],uchar &patched_model[])
  {
   ArrayCopy(patched_model,original_model,0,0,WHOLE_ARRAY);
//--- special ONNX model patch(IR=9,Opset=20)
   patched_model[1]=0x09;
   patched_model[ArraySize(patched_model)-1]=0x14;
  }
//+------------------------------------------------------------------+
//| CreateModel                                                      |
//+------------------------------------------------------------------+
bool CreateModel(long &model_handle,const uchar &model[])
  {
   model_handle=INVALID_HANDLE;
   ulong flags=ONNX_DEFAULT;
//ulong flags=ONNX_DEBUG_LOGS;
//---
   model_handle=OnnxCreateFromBuffer(model,flags);
   if(model_handle==INVALID_HANDLE)
      return(false);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| PrepareShapes                                                    |
//+------------------------------------------------------------------+
bool PrepareShapes(long model_handle)
  {
   ulong input_shape1[]= {3,4};
   if(!OnnxSetInputShape(model_handle,0,input_shape1))
     {
      PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError());
      //--
      OnnxRelease(model_handle);
      return(false);
     }
//---
   ulong output_shape[]= {3,4};
   if(!OnnxSetOutputShape(model_handle,0,output_shape))
     {
      PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError());
      //--
      OnnxRelease(model_handle);
      return(false);
     }
//---
   return(true);
  }

//+------------------------------------------------------------------+
//| RunCastBFloat16ToFloat                                           |
//+------------------------------------------------------------------+
bool RunCastBFloat16ToFloat(long model_handle)
  {
   PrintFormat("test=%s",__FUNCTION__);

   double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12};
   ushort data_uint16[12];
   if(!ArrayToFP16(data_uint16,test_data,FLOAT_BFP16))
     {
      Print("error in ArrayToFP16. error code=",GetLastError());
      return(false);
     }
   Print("test array:");
   ArrayPrint(test_data);
   Print("ArrayToFP16:");
   ArrayPrint(data_uint16);

   U<ushort> input_float16_values[3*4];
   U<float>  output_float_values[3*4];

   float test_data_float[];
   if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_BFP16))
     {
      Print("error in ArrayFromFP16. error code=",GetLastError());
      return(false);
     }

   for(int i=0; i<12; i++)
     {
      input_float16_values[i].value=data_uint16[i];
      PrintFormat("%d input value =%f  Hex float16 = %s  ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value);
     }

   Print("ONNX input array:");
   ArrayPrint(input_float16_values);

   bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_float_values);
   if(!res)
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      return(false);
     }

   Print("ONNX output array:");
   ArrayPrint(output_float_values);
//---
   double sum_error=0.0;
   for(int i=0; i<12; i++)
     {
      double delta=test_data[i]-(double)output_float_values[i].value;
      sum_error+=MathAbs(delta);
      PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToString(output_float_values[i].uc),delta);
     }
//---
   PrintFormat("test=%s   sum_error=%f",__FUNCTION__,sum_error);
//---
   return(true);
  }

//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
int OnStart(void)
  {
   uchar model[];
   PatchONNXModel(ExtModel1,model);
//--- get model handle
   long model_handle=INVALID_HANDLE;
//--- get model handle
   if(!CreateModel(model_handle,model))
      return 1;
//--- prepare input and output shapes
   if(!PrepareShapes(model_handle))
      return 1;
//--- run ONNX model
   if(!RunCastBFloat16ToFloat(model_handle))
      return 1;
//--- release model handle
   OnnxRelease(model_handle);
//---
   return 0;
  }
//+------------------------------------------------------------------+
Результат:
TestCastBFloat16 (EURUSD,H1)    test=RunCastBFloat16ToFloat
TestCastBFloat16 (EURUSD,H1)    test array:
TestCastBFloat16 (EURUSD,H1)     1.00000  2.00000  3.00000  4.00000  5.00000  6.00000  7.00000  8.00000  9.00000 10.00000 11.00000 12.00000
TestCastBFloat16 (EURUSD,H1)    ArrayToFP16:
TestCastBFloat16 (EURUSD,H1)    16256 16384 16448 16512 16544 16576 16608 16640 16656 16672 16688 16704
TestCastBFloat16 (EURUSD,H1)    0 input value =1.000000  Hex float16 = [80,3f]  ushort value=16256
TestCastBFloat16 (EURUSD,H1)    1 input value =2.000000  Hex float16 = [00,40]  ushort value=16384
TestCastBFloat16 (EURUSD,H1)    2 input value =3.000000  Hex float16 = [40,40]  ushort value=16448
TestCastBFloat16 (EURUSD,H1)    3 input value =4.000000  Hex float16 = [80,40]  ushort value=16512
TestCastBFloat16 (EURUSD,H1)    4 input value =5.000000  Hex float16 = [a0,40]  ushort value=16544
TestCastBFloat16 (EURUSD,H1)    5 input value =6.000000  Hex float16 = [c0,40]  ushort value=16576
TestCastBFloat16 (EURUSD,H1)    6 input value =7.000000  Hex float16 = [e0,40]  ushort value=16608
TestCastBFloat16 (EURUSD,H1)    7 input value =8.000000  Hex float16 = [00,41]  ushort value=16640
TestCastBFloat16 (EURUSD,H1)    8 input value =9.000000  Hex float16 = [10,41]  ushort value=16656
TestCastBFloat16 (EURUSD,H1)    9 input value =10.000000  Hex float16 = [20,41]  ushort value=16672
TestCastBFloat16 (EURUSD,H1)    10 input value =11.000000  Hex float16 = [30,41]  ushort value=16688
TestCastBFloat16 (EURUSD,H1)    11 input value =12.000000  Hex float16 = [40,41]  ushort value=16704
TestCastBFloat16 (EURUSD,H1)    ONNX input array:
TestCastBFloat16 (EURUSD,H1)         [uc] [value]
TestCastBFloat16 (EURUSD,H1)    [ 0]  ...   16256
TestCastBFloat16 (EURUSD,H1)    [ 1]  ...   16384
TestCastBFloat16 (EURUSD,H1)    [ 2]  ...   16448
TestCastBFloat16 (EURUSD,H1)    [ 3]  ...   16512
TestCastBFloat16 (EURUSD,H1)    [ 4]  ...   16544
TestCastBFloat16 (EURUSD,H1)    [ 5]  ...   16576
TestCastBFloat16 (EURUSD,H1)    [ 6]  ...   16608
TestCastBFloat16 (EURUSD,H1)    [ 7]  ...   16640
TestCastBFloat16 (EURUSD,H1)    [ 8]  ...   16656
TestCastBFloat16 (EURUSD,H1)    [ 9]  ...   16672
TestCastBFloat16 (EURUSD,H1)    [10]  ...   16688
TestCastBFloat16 (EURUSD,H1)    [11]  ...   16704
TestCastBFloat16 (EURUSD,H1)    ONNX output array:
TestCastBFloat16 (EURUSD,H1)         [uc]  [value]
TestCastBFloat16 (EURUSD,H1)    [ 0]  ...  1.00000
TestCastBFloat16 (EURUSD,H1)    [ 1]  ...  2.00000
TestCastBFloat16 (EURUSD,H1)    [ 2]  ...  3.00000
TestCastBFloat16 (EURUSD,H1)    [ 3]  ...  4.00000
TestCastBFloat16 (EURUSD,H1)    [ 4]  ...  5.00000
TestCastBFloat16 (EURUSD,H1)    [ 5]  ...  6.00000
TestCastBFloat16 (EURUSD,H1)    [ 6]  ...  7.00000
TestCastBFloat16 (EURUSD,H1)    [ 7]  ...  8.00000
TestCastBFloat16 (EURUSD,H1)    [ 8]  ...  9.00000
TestCastBFloat16 (EURUSD,H1)    [ 9]  ... 10.00000
TestCastBFloat16 (EURUSD,H1)    [10]  ... 11.00000
TestCastBFloat16 (EURUSD,H1)    [11]  ... 12.00000
TestCastBFloat16 (EURUSD,H1)    0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    8 output float 9.000000 = [00,00,10,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    10 output float 11.000000 = [00,00,30,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    test=RunCastBFloat16ToFloat   sum_error=0.000000


1.2. Формат FP8

Современные языковые модели могут содержать миллиарды параметров. Обучение моделей с использованием чисел FP16 уже показало свою эффективность. Переход от 16-битного числа с плавающей точкой к FP8 позволяет вдвое уменьшить требования к памяти а также ускорить обучение и исполнение моделей.

Формат FP8 (8-битное число с плавающей запятой) представляет собой один из типов данных, используемых для представления чисел с плавающей точкой. В FP8 каждое число представлено 8 битами данных, которые обычно разделяются на три компонента: знак, экспоненту и мантиссу. Этот формат обеспечивает компромисс между точностью и эффективностью хранения данных, что делает его привлекательным для использования в приложениях, где требуется экономия памяти и вычислительных ресурсов. 

Одним из ключевых преимуществ FP8 является его эффективность при обработке больших объемов данных. Благодаря компактному представлению чисел, FP8 позволяет уменьшить требования к памяти и ускорить вычисления. Это особенно важно в приложениях машинного обучения и искусственного интеллекта, где обработка больших наборов данных является обычным делом.

Кроме того, FP8 может быть полезен для реализации низкоуровневых операций, таких как арифметические вычисления и обработка сигналов. Его компактный формат делает его подходящим для использования во встраиваемых системах и приложениях, где ограничены ресурсы. Однако, следует отметить, что FP8 имеет свои ограничения, связанные с его ограниченной точностью. В некоторых приложениях, где требуется высокая точность вычислений, таких как научные вычисления или финансовая аналитика, использование FP8 может быть недостаточным.


1.2.1. Форматы fp8_e5m2 и fp8_e4m3

В 2022 году было опубликовано две статьи, вводящие числа с плавающей точкой, хранящиеся в одном байте, в отличие от чисел float32, хранящихся в 4 байтах.

В статье FP8 Formats for Deep Learning (2022) от NVIDIA, Intel и ARM вводятся два типа, следуя спецификациям IEEE. Первый тип - это E4M3, 1 бит для знака, 4 бита для экспоненты и 3 бита для мантиссы. Второй тип - это E5M2, 1 бит для знака, 5 бит для экспоненты и 2 для мантиссы.Первый тип обычно используется для весов, второй - для градиентов.

Вторая статья "8-bit Numerical Formats For Deep Neural Networks" представляет схожие типы. Стандарт IEEE присваивает одинаковое значение +0 (или целое число 0) и -0 (или целое число 128).  В статье предлагается присвоить различные значения float этим двум числам. Кроме того, также исследуются различные разделения между экспонентой и мантиссой, и показано, что E4M3 и E5M2 являются лучшими.

В результате в ONNX (с версии 1.15.0) было введено 4 новых типа:

  • E4M3FN: 1 бит для знака, 4 бита для экспоненты, 3 бита для мантиссы, только значения NaN и нет бесконечных значений (FN),
  • E4M3FNUZ: 1 бит для знака, 4 бита для экспоненты, 3 бита для мантиссы, только значения NaN и нет бесконечных значений (FN), нет отрицательного нуля (UZ)
  • E5M2: 1 бит для знака, 5 бит для экспоненты, 2 бита для мантиссы,
  • E5M2FNUZ: 1 бит для знака, 5 бит для экспоненты, 2 бита для мантиссы, только значения NaN и нет бесконечных значений (FN), нет отрицательного нуля (UZ)

Реализация обычно зависит от аппаратных средств. NVIDIA, Intel и Arm реализуют E4M3FN, а E5M2 реализованы в современных графических процессорах. GraphCore делает то же самое только с E4M3FNUZ и E5M2FNUZ.

Приведем кратко основные сведения о типе FP8 согласно статье NVIDIA Hopper: H100 and FP8 Support.


Рис. Формат битового представления чисел с плавающей точкой FP8_E4M3

Рис.5. Формат битового представления чисел с плавающей точкой FP8_E4M3


Детали формата FP8_E5M2

Табл.3. Числа с плавающей точкой в формате E5M2


Детали формата FP8_E5M2

Табл.4. Числа с плавающей точкой в формате E4M3

Сравнение дипазонов положительных значений чисел FP8_E4M3 и FP8_E5M2 приведены на рис

 Сравнение диапазонов положительных значений чисел FP8

Рис.6. Сравнение диапазонов положительных значений чисел FP8 (источник)


Сравнение точности выполнения арифметических операций (Add, Mul, Div) для чисел в форматах FP8_E5M2  и FP8_E4M3 приведены на рис:

Сравнение точности арифметических операций для чисел в форматах float8_e5m2 и float8_e4m3 (источник)

Рис.7. Сравнение точности арифметических операций для чисел в форматах float8_e5m2 и float8_e4m3 (источник)

Рекомендуемое использование чисел в формате FP8:


1.2.2. Тесты исполнения ONNX-оператора Cast для FLOAT8

В этом примере рассматривается преобразование из различных типов FLOAT8 в float.

ONNX-модели с операцией Cast:


Рис.8. Входные и выходные параметры модели test_cast_FLOAT8E4M3FN_to_FLOAT.onnx в MetaEditor

Рис.8. Входные и выходные параметры модели test_cast_FLOAT8E4M3FN_to_FLOAT.onnx в MetaEditor



 Рис.9. Входные и выходные параметры модели test_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnx в MetaEditor

Рис.9. Входные и выходные параметры модели test_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnx в MetaEditor


Рис.10. Входные и выходные параметры модели test_cast_FLOAT8E5M2_to_FLOAT.onnx в MetaEditor

Рис.10. Входные и выходные параметры модели test_cast_FLOAT8E5M2_to_FLOAT.onnx в MetaEditor


 Рис.11. Входные и выходные параметры модели test_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnx в MetaEditor

Рис.11. Входные и выходные параметры модели test_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnx в MetaEditor


Пример:

//+------------------------------------------------------------------+
//|                                              TestCastBFloat8.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"

#resource "models\\test_cast_FLOAT8E4M3FN_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E4M3FN_to_FLOAT[];
#resource "models\\test_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E4M3FNUZ_to_FLOAT[];
#resource "models\\test_cast_FLOAT8E5M2_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E5M2_to_FLOAT[];
#resource "models\\test_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E5M2FNUZ_to_FLOAT[];

#define TEST_PASSED 0
#define TEST_FAILED 1
//+------------------------------------------------------------------+
//| union for data conversion                                        |
//+------------------------------------------------------------------+
template<typename T>
union U
  {
   uchar uc[sizeof(T)];
   T value;
  };
//+------------------------------------------------------------------+
//| ArrayToHexString                                                 |
//+------------------------------------------------------------------+
template<typename T>
string ArrayToHexString(const T &data[],uint length=16)
  {
   string res;

   for(uint n=0; n<MathMin(length,data.Size()); n++)
      res+="," + StringFormat("%.2x",data[n]);

   StringSetCharacter(res,0,'[');
   return(res+"]");
  }
//+------------------------------------------------------------------+
//| ArrayToString                                                    |
//+------------------------------------------------------------------+
template<typename T>
string ArrayToString(const U<T> &data[],uint length=16)
  {
   string res;

   for(uint n=0; n<MathMin(length,data.Size()); n++)
      res+="," + (string)data[n].value;

   StringSetCharacter(res,0,'[');
   return(res+"]");
  }
//+------------------------------------------------------------------+
//| PatchONNXModel                                                   |
//+------------------------------------------------------------------+
long CreatePatchedModel(const uchar &original_model[])
  {
   uchar patched_model[];
   ArrayCopy(patched_model,original_model);
//--- special ONNX model patch(IR=9,Opset=20)
   patched_model[1]=0x09;
   patched_model[ArraySize(patched_model)-1]=0x14;

   return(OnnxCreateFromBuffer(patched_model,ONNX_DEFAULT));
  }
//+------------------------------------------------------------------+
//| PrepareShapes                                                    |
//+------------------------------------------------------------------+
bool PrepareShapes(long model_handle)
  {
//--- configure input shape
   ulong input_shape[]= {3,5};

   if(!OnnxSetInputShape(model_handle,0,input_shape))
     {
      PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(false);
     }
//--- configure output shape
   ulong output_shape[]= {3,5};

   if(!OnnxSetOutputShape(model_handle,0,output_shape))
     {
      PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(false);
     }

   return(true);
  }
//+------------------------------------------------------------------+
//| RunCastFloat8Float                                               |
//+------------------------------------------------------------------+
bool RunCastFloat8ToFloat(long model_handle,const ENUM_FLOAT8_FORMAT fmt)
  {
   PrintFormat("TEST: %s(%s)",__FUNCTION__,EnumToString(fmt));
//---
   float test_data[15]   = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15};
   uchar data_float8[15] = {};

   if(!ArrayToFP8(data_float8,test_data,fmt))
     {
      Print("error in ArrayToFP8. error code=",GetLastError());
      OnnxRelease(model_handle);
      return(false);
     }

   U<uchar> input_float8_values[3*5];
   U<float> output_float_values[3*5];
   float    test_data_float[];
//--- convert float8 to float
   if(!ArrayFromFP8(test_data_float,data_float8,fmt))
     {
      Print("error in ArrayFromFP8. error code=",GetLastError());
      OnnxRelease(model_handle);
      return(false);
     }

   for(uint i=0; i<data_float8.Size(); i++)
     {
      input_float8_values[i].value=data_float8[i];
      PrintFormat("%d input value =%f  Hex float8 = %s  ushort value=%d",i,test_data_float[i],ArrayToHexString(input_float8_values[i].uc),input_float8_values[i].value);
     }

   Print("ONNX input array: ",ArrayToString(input_float8_values));
//--- execute model (convert float8 to float using ONNX)
   if(!OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float8_values,output_float_values))
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(false);
     }

   Print("ONNX output array: ",ArrayToString(output_float_values));
//--- calculate error (compare ONNX and ArrayFromFP8 results)
   double sum_error=0.0;

   for(uint i=0; i<test_data.Size(); i++)
     {
      double delta=test_data_float[i]-(double)output_float_values[i].value;
      sum_error+=MathAbs(delta);
      PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToHexString(output_float_values[i].uc),delta);
     }
//---
   PrintFormat("%s(%s): sum_error=%f\n",__FUNCTION__,EnumToString(fmt),sum_error);
   return(true);
  }
//+------------------------------------------------------------------+
//| TestModel                                                        |
//+------------------------------------------------------------------+
bool TestModel(const uchar &model[],const ENUM_FLOAT8_FORMAT fmt)
  {
//--- create patched model
   long model_handle=CreatePatchedModel(model);

   if(model_handle==INVALID_HANDLE)
      return(false);
//--- prepare input and output shapes
   if(!PrepareShapes(model_handle))
      return(false);
//--- run ONNX model
   if(!RunCastFloat8ToFloat(model_handle,fmt))
      return(false);
//--- release model handle
   OnnxRelease(model_handle);

   return(true);
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
int OnStart(void)
  {
//--- run ONNX model
   if(!TestModel(ExtModel_FLOAT8E4M3FN_to_FLOAT,FLOAT_FP8_E4M3FN))
      return(TEST_FAILED);

//--- run ONNX model
   if(!TestModel(ExtModel_FLOAT8E4M3FNUZ_to_FLOAT,FLOAT_FP8_E4M3FNUZ))
      return(TEST_FAILED);

//--- run ONNX model
   if(!TestModel(ExtModel_FLOAT8E5M2_to_FLOAT,FLOAT_FP8_E5M2FN))
      return(TEST_FAILED);

//--- run ONNX model
   if(!TestModel(ExtModel_FLOAT8E5M2FNUZ_to_FLOAT,FLOAT_FP8_E5M2FNUZ))
      return(TEST_FAILED);

   return(TEST_PASSED);
  }
//+------------------------------------------------------------------+

Результат:

TestCastFloat8 (EURUSD,H1)      TEST: RunCastFloat8ToFloat(FLOAT_FP8_E4M3FN)
TestCastFloat8 (EURUSD,H1)      0 input value =1.000000  Hex float8 = [38]  ushort value=56
TestCastFloat8 (EURUSD,H1)      1 input value =2.000000  Hex float8 = [40]  ushort value=64
TestCastFloat8 (EURUSD,H1)      2 input value =3.000000  Hex float8 = [44]  ushort value=68
TestCastFloat8 (EURUSD,H1)      3 input value =4.000000  Hex float8 = [48]  ushort value=72
TestCastFloat8 (EURUSD,H1)      4 input value =5.000000  Hex float8 = [4a]  ushort value=74
TestCastFloat8 (EURUSD,H1)      5 input value =6.000000  Hex float8 = [4c]  ushort value=76
TestCastFloat8 (EURUSD,H1)      6 input value =7.000000  Hex float8 = [4e]  ushort value=78
TestCastFloat8 (EURUSD,H1)      7 input value =8.000000  Hex float8 = [50]  ushort value=80
TestCastFloat8 (EURUSD,H1)      8 input value =9.000000  Hex float8 = [51]  ushort value=81
TestCastFloat8 (EURUSD,H1)      9 input value =10.000000  Hex float8 = [52]  ushort value=82
TestCastFloat8 (EURUSD,H1)      10 input value =11.000000  Hex float8 = [53]  ushort value=83
TestCastFloat8 (EURUSD,H1)      11 input value =12.000000  Hex float8 = [54]  ushort value=84
TestCastFloat8 (EURUSD,H1)      12 input value =13.000000  Hex float8 = [55]  ushort value=85
TestCastFloat8 (EURUSD,H1)      13 input value =14.000000  Hex float8 = [56]  ushort value=86
TestCastFloat8 (EURUSD,H1)      14 input value =15.000000  Hex float8 = [57]  ushort value=87
TestCastFloat8 (EURUSD,H1)      ONNX input array: [56,64,68,72,74,76,78,80,81,82,83,84,85,86,87]
TestCastFloat8 (EURUSD,H1)      ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0]
TestCastFloat8 (EURUSD,H1)      0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastFloat8 (EURUSD,H1)      1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      8 output float 9.000000 = [00,00,10,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      10 output float 11.000000 = [00,00,30,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      12 output float 13.000000 = [00,00,50,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      13 output float 14.000000 = [00,00,60,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      14 output float 15.000000 = [00,00,70,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      RunCastFloat8ToFloat(FLOAT_FP8_E4M3FN): sum_error=0.000000
TestCastFloat8 (EURUSD,H1)      
TestCastFloat8 (EURUSD,H1)      TEST: RunCastFloat8ToFloat(FLOAT_FP8_E4M3FNUZ)
TestCastFloat8 (EURUSD,H1)      0 input value =1.000000  Hex float8 = [40]  ushort value=64
TestCastFloat8 (EURUSD,H1)      1 input value =2.000000  Hex float8 = [48]  ushort value=72
TestCastFloat8 (EURUSD,H1)      2 input value =3.000000  Hex float8 = [4c]  ushort value=76
TestCastFloat8 (EURUSD,H1)      3 input value =4.000000  Hex float8 = [50]  ushort value=80
TestCastFloat8 (EURUSD,H1)      4 input value =5.000000  Hex float8 = [52]  ushort value=82
TestCastFloat8 (EURUSD,H1)      5 input value =6.000000  Hex float8 = [54]  ushort value=84
TestCastFloat8 (EURUSD,H1)      6 input value =7.000000  Hex float8 = [56]  ushort value=86
TestCastFloat8 (EURUSD,H1)      7 input value =8.000000  Hex float8 = [58]  ushort value=88
TestCastFloat8 (EURUSD,H1)      8 input value =9.000000  Hex float8 = [59]  ushort value=89
TestCastFloat8 (EURUSD,H1)      9 input value =10.000000  Hex float8 = [5a]  ushort value=90
TestCastFloat8 (EURUSD,H1)      10 input value =11.000000  Hex float8 = [5b]  ushort value=91
TestCastFloat8 (EURUSD,H1)      11 input value =12.000000  Hex float8 = [5c]  ushort value=92
TestCastFloat8 (EURUSD,H1)      12 input value =13.000000  Hex float8 = [5d]  ushort value=93
TestCastFloat8 (EURUSD,H1)      13 input value =14.000000  Hex float8 = [5e]  ushort value=94
TestCastFloat8 (EURUSD,H1)      14 input value =15.000000  Hex float8 = [5f]  ushort value=95
TestCastFloat8 (EURUSD,H1)      ONNX input array: [64,72,76,80,82,84,86,88,89,90,91,92,93,94,95]
TestCastFloat8 (EURUSD,H1)      ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0]
TestCastFloat8 (EURUSD,H1)      0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastFloat8 (EURUSD,H1)      1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      8 output float 9.000000 = [00,00,10,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      10 output float 11.000000 = [00,00,30,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      12 output float 13.000000 = [00,00,50,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      13 output float 14.000000 = [00,00,60,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      14 output float 15.000000 = [00,00,70,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      RunCastFloat8ToFloat(FLOAT_FP8_E4M3FNUZ): sum_error=0.000000
TestCastFloat8 (EURUSD,H1)      
TestCastFloat8 (EURUSD,H1)      TEST: RunCastFloat8ToFloat(FLOAT_FP8_E5M2FN)
TestCastFloat8 (EURUSD,H1)      0 input value =1.000000  Hex float8 = [3c]  ushort value=60
TestCastFloat8 (EURUSD,H1)      1 input value =2.000000  Hex float8 = [40]  ushort value=64
TestCastFloat8 (EURUSD,H1)      2 input value =3.000000  Hex float8 = [42]  ushort value=66
TestCastFloat8 (EURUSD,H1)      3 input value =4.000000  Hex float8 = [44]  ushort value=68
TestCastFloat8 (EURUSD,H1)      4 input value =5.000000  Hex float8 = [45]  ushort value=69
TestCastFloat8 (EURUSD,H1)      5 input value =6.000000  Hex float8 = [46]  ushort value=70
TestCastFloat8 (EURUSD,H1)      6 input value =7.000000  Hex float8 = [47]  ushort value=71
TestCastFloat8 (EURUSD,H1)      7 input value =8.000000  Hex float8 = [48]  ushort value=72
TestCastFloat8 (EURUSD,H1)      8 input value =8.000000  Hex float8 = [48]  ushort value=72
TestCastFloat8 (EURUSD,H1)      9 input value =10.000000  Hex float8 = [49]  ushort value=73
TestCastFloat8 (EURUSD,H1)      10 input value =12.000000  Hex float8 = [4a]  ushort value=74
TestCastFloat8 (EURUSD,H1)      11 input value =12.000000  Hex float8 = [4a]  ushort value=74
TestCastFloat8 (EURUSD,H1)      12 input value =12.000000  Hex float8 = [4a]  ushort value=74
TestCastFloat8 (EURUSD,H1)      13 input value =14.000000  Hex float8 = [4b]  ushort value=75
TestCastFloat8 (EURUSD,H1)      14 input value =16.000000  Hex float8 = [4c]  ushort value=76
TestCastFloat8 (EURUSD,H1)      ONNX input array: [60,64,66,68,69,70,71,72,72,73,74,74,74,75,76]
TestCastFloat8 (EURUSD,H1)      ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,8.0,10.0,12.0,12.0,12.0,14.0,16.0]
TestCastFloat8 (EURUSD,H1)      0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastFloat8 (EURUSD,H1)      1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      8 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      10 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      12 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      13 output float 14.000000 = [00,00,60,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      14 output float 16.000000 = [00,00,80,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      RunCastFloat8ToFloat(FLOAT_FP8_E5M2FN): sum_error=0.000000
TestCastFloat8 (EURUSD,H1)      
TestCastFloat8 (EURUSD,H1)      TEST: RunCastFloat8ToFloat(FLOAT_FP8_E5M2FNUZ)
TestCastFloat8 (EURUSD,H1)      0 input value =1.000000  Hex float8 = [40]  ushort value=64
TestCastFloat8 (EURUSD,H1)      1 input value =2.000000  Hex float8 = [44]  ushort value=68
TestCastFloat8 (EURUSD,H1)      2 input value =3.000000  Hex float8 = [46]  ushort value=70
TestCastFloat8 (EURUSD,H1)      3 input value =4.000000  Hex float8 = [48]  ushort value=72
TestCastFloat8 (EURUSD,H1)      4 input value =5.000000  Hex float8 = [49]  ushort value=73
TestCastFloat8 (EURUSD,H1)      5 input value =6.000000  Hex float8 = [4a]  ushort value=74
TestCastFloat8 (EURUSD,H1)      6 input value =7.000000  Hex float8 = [4b]  ushort value=75
TestCastFloat8 (EURUSD,H1)      7 input value =8.000000  Hex float8 = [4c]  ushort value=76
TestCastFloat8 (EURUSD,H1)      8 input value =8.000000  Hex float8 = [4c]  ushort value=76
TestCastFloat8 (EURUSD,H1)      9 input value =10.000000  Hex float8 = [4d]  ushort value=77
TestCastFloat8 (EURUSD,H1)      10 input value =12.000000  Hex float8 = [4e]  ushort value=78
TestCastFloat8 (EURUSD,H1)      11 input value =12.000000  Hex float8 = [4e]  ushort value=78
TestCastFloat8 (EURUSD,H1)      12 input value =12.000000  Hex float8 = [4e]  ushort value=78
TestCastFloat8 (EURUSD,H1)      13 input value =14.000000  Hex float8 = [4f]  ushort value=79
TestCastFloat8 (EURUSD,H1)      14 input value =16.000000  Hex float8 = [50]  ushort value=80
TestCastFloat8 (EURUSD,H1)      ONNX input array: [64,68,70,72,73,74,75,76,76,77,78,78,78,79,80]
TestCastFloat8 (EURUSD,H1)      ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,8.0,10.0,12.0,12.0,12.0,14.0,16.0]
TestCastFloat8 (EURUSD,H1)      0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastFloat8 (EURUSD,H1)      1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastFloat8 (EURUSD,H1)      7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      8 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      10 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      12 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      13 output float 14.000000 = [00,00,60,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      14 output float 16.000000 = [00,00,80,41] difference=0.000000
TestCastFloat8 (EURUSD,H1)      RunCastFloat8ToFloat(FLOAT_FP8_E5M2FNUZ): sum_error=0.000000
TestCastFloat8 (EURUSD,H1)



2. Пример использования ONNX для повышения разрешения изображений

В этом разделе мы рассмотрим пример использования моделей ESRGAN для увеличения разрешения изображений.

ESRGAN, или Enhanced Super-Resolution Generative Adversarial Networks, представляет собой мощную архитектуру нейронных сетей, разработанную для решения задачи суперразрешения изображений. ESRGAN разработан с целью улучшения качества изображений, повышая их разрешение до более высокого уровня. Это осуществляется путем обучения глубокой нейронной сети на большом наборе данных низкого разрешения и соответствующих им высококачественных изображений. ESRGAN использует архитектуру генеративно-состязательных сетей (GANs), которая состоит из двух основных компонентов: генератора и дискриминатора. Генератор отвечает за создание изображений высокого разрешения, тогда как дискриминатор обучается отличать сгенерированные изображения от реальных.

В основе архитектуры ESRGAN лежат резидуальные блоки, которые помогают извлекать и сохранять важные признаки изображений на разных уровнях абстракции. Это позволяет сети эффективно восстанавливать детали и текстуры на изображениях с высоким качеством.

Для достижения высокого качества и общности в решении задачи суперразрешения, ESRGAN требует обширных наборов данных для обучения. Это позволяет сети изучать различные стили и характеристики изображений и делает ее более адаптивной к различным типам входных данных. ESRGAN может быть использован для улучшения качества изображений во многих областях, включая фотографии, медицинскую диагностику, кино и видеопроизводство, графический дизайн и многое другое. Его гибкость и эффективность делают его одним из ведущих методов в области суперразрешения изображений.

ESRGAN представляет собой важный шаг вперед в области обработки изображений и искусственного интеллекта, открывая новые возможности для создания и улучшения изображений.


2.1. Пример исполнения ONNX-модели с float32

Для исполнения примера требуется скачать файл https://github.com/amannm/super-resolution-service/blob/main/models/esrgan.onnx и скопировать его в папку \MQL5\Scripts\models.

Модель ESRGAN.onnx содержит ~1200 ONNX-операций, начальные из них представлены на рис.12

Рис. Модель ESRGAN в MetaEditor

Рис.12. Модель ESRGAN в MetaEditor



Рис. Модель ESRGAN в Netron

Рис.13. Модель ESRGAN в Netron


Приведенный ниже код представляет собой демонстрацию увеличения размера изображения в 4 раза с использованием ESRGAN.onnx.

Он начинается с загрузки модели esrgan.onnxm, затем выбирается и загружается исходное изображение в формате BMP. После этого изображение конвертируется в отдельные каналы RGB, которые подаются на вход модели. Модель выполняет процесс увеличения размера изображения в 4 раза, после чего полученное увеличенное изображение проходит обратное преобразование и подготавливается к отображению.

Для отображения используется библиотека Canvas, а для выполнения модели - библиотека ONNX Runtime. После выполнения программы увеличенное изображение сохраняется в файл с добавлением "_upscaled" к имени исходного файла. Основные функции включают предобработку и постобработку изображения, а также выполнение модели для увеличения размера изображения.

//+------------------------------------------------------------------+
//|                                                       ESRGAN.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
//+------------------------------------------------------------------+
//| 4x image upscaling demo using ESRGAN                             |
//| esrgan.onnx model from                                           |
//| https://github.com/amannm/super-resolution-service/              |
//+------------------------------------------------------------------+
//| Xintao Wang et al (2018)                                         |
//| ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks|
//| https://arxiv.org/abs/1809.00219                                 |
//+------------------------------------------------------------------+
#resource "models\\esrgan.onnx" as uchar ExtModel[];
#include <Canvas\Canvas.mqh>
//+------------------------------------------------------------------+
//| clamp                                                            |
//+------------------------------------------------------------------+
float clamp(float value, float minValue, float maxValue)
  {
   return MathMin(MathMax(value, minValue), maxValue);
  }
//+------------------------------------------------------------------+
//| Preprocessing                                                    |
//+------------------------------------------------------------------+
bool Preprocessing(float &data[],uint &image_data[],int &image_width,int &image_height)
  {
//--- checkup
   if(image_height==0 || image_width==0)
      return(false);
//--- prepare destination array with separated RGB channels for ONNX model
   int data_count=3*image_width*image_height;

   if(ArrayResize(data,data_count)!=data_count)
     {
      Print("ArrayResize failed");
      return(false);
     }
//--- converting
   for(int y=0; y<image_height; y++)
      for(int x=0; x<image_width; x++)
        {
         //--- load source RGB
         int   offset=y*image_width+x;
         uint  clr   =image_data[offset];
         uchar r     =GETRGBR(clr);
         uchar g     =GETRGBG(clr);
         uchar b     =GETRGBB(clr);
         //--- store RGB components as separated channels
         int offset_ch1=0*image_width*image_height+offset;
         int offset_ch2=1*image_width*image_height+offset;
         int offset_ch3=2*image_width*image_height+offset;

         data[offset_ch1]=r/255.0f;
         data[offset_ch2]=g/255.0f;
         data[offset_ch3]=b/255.0f;
        }
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| PostProcessing                                                   |
//+------------------------------------------------------------------+
bool PostProcessing(const float &data[], uint &image_data[], const int &image_width, const int &image_height)
  {
//--- checks
   if(image_height == 0 || image_width == 0)
      return(false);

   int data_count=image_width*image_height;
   
   if(ArraySize(data)!=3*data_count)
      return(false);
   if(ArrayResize(image_data,data_count)!=data_count)
      return(false);
//---
   for(int y=0; y<image_height; y++)
      for(int x=0; x<image_width; x++)
        {
         int offset    =y*image_width+x;
         int offset_ch1=0*image_width*image_height+offset;
         int offset_ch2=1*image_width*image_height+offset;
         int offset_ch3=2*image_width*image_height+offset;
         //--- rescale to [0..255]
         float r=clamp(data[offset_ch1]*255,0,255);
         float g=clamp(data[offset_ch2]*255,0,255);
         float b=clamp(data[offset_ch3]*255,0,255);
         //--- set color image_data
         image_data[offset]=XRGB(uchar(r),uchar(g),uchar(b));
        }
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| ShowImage                                                        |
//+------------------------------------------------------------------+
bool ShowImage(CCanvas &canvas,const string name,const int x0,const int y0,const int image_width,const int image_height, const uint &image_data[])
  {
   if(ArraySize(image_data)==0 || name=="")
      return(false);
//--- prepare canvas
   canvas.CreateBitmapLabel(name,x0,y0,image_width,image_height,COLOR_FORMAT_XRGB_NOALPHA);
//--- copy image to canvas
   for(int y=0; y<image_height; y++)
      for(int x=0; x<image_width; x++)
         canvas.PixelSet(x,y,image_data[y*image_width+x]);
//--- ready to draw
   canvas.Update(true);
   return(true);
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
int OnStart(void)
  {
//--- select BMP from <data folder>\MQL5\Files
   string image_path[1];

   if(FileSelectDialog("Select BMP image",NULL,"Bitmap files (*.bmp)|*.bmp",FSD_FILE_MUST_EXIST,image_path,"lenna-original4.bmp")!=1)
     {
      Print("file not selected");
      return(-1);
     }
//--- load BMP into array
   uint image_data[];
   int  image_width;
   int  image_height;

   if(!CCanvas::LoadBitmap(image_path[0],image_data,image_width,image_height))
     {
      PrintFormat("CCanvas::LoadBitmap failed with error %d",GetLastError());
      return(-1);
     }
//--- convert RGB image to separated RGB channels
   float input_data[];
   Preprocessing(input_data,image_data,image_width,image_height);
   PrintFormat("input array size=%d",ArraySize(input_data));
//--- load model
   long model_handle=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT);

   if(model_handle==INVALID_HANDLE)
     {
      PrintFormat("OnnxCreate error %d",GetLastError());
      return(-1);
     }

   PrintFormat("model loaded successfully");
   PrintFormat("original:  width=%d, height=%d  Size=%d",image_width,image_height,ArraySize(image_data));
//--- set input shape
   ulong input_shape[]={1,3,image_height,image_width};

   if(!OnnxSetInputShape(model_handle,0,input_shape))
     {
      PrintFormat("error in OnnxSetInputShape. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }
//--- upscaled image size
   int   new_image_width =4*image_width;
   int   new_image_height=4*image_height;
   ulong output_shape[]= {1,3,new_image_height,new_image_width};

   if(!OnnxSetOutputShape(model_handle,0,output_shape))
     {
      PrintFormat("error in OnnxSetOutputShape. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }
//--- run the model
   float output_data[];
   int new_data_count=3*new_image_width*new_image_height;
   if(ArrayResize(output_data,new_data_count)!=new_data_count)
     {
      OnnxRelease(model_handle);
      return(-1);
     }

   if(!OnnxRun(model_handle,ONNX_DEBUG_LOGS,input_data,output_data))
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }

   Print("model successfully executed, output data size ",ArraySize(output_data));
   OnnxRelease(model_handle);
//--- postprocessing
   uint new_image[];
   PostProcessing(output_data,new_image,new_image_width,new_image_height);
//--- show images
   CCanvas canvas_original,canvas_scaled;
   ShowImage(canvas_original,"original_image",new_image_width,0,image_width,image_height,image_data);
   ShowImage(canvas_scaled,"upscaled_image",0,0,new_image_width,new_image_height,new_image);
//--- save upscaled image
   StringReplace(image_path[0],".bmp","_upscaled.bmp");
   Print(ResourceSave(canvas_scaled.ResourceName(),image_path[0]));
//---
   while(!IsStopped())
      Sleep(100);

   return(0);
  }
//+------------------------------------------------------------------+

Результат:

Рис. Увеличение изображения 160x200 в 4 раза при помощи модели ESRGAN.onnx

Рис.14. Результат работы модели ESRGAN.onnx (160x200 -> 640x800)

В данном примере изображение 160x200 было увеличено в 4 раза (до 640x800) при помощи модели ESRGAN.onnx.


2.2. Пример исполнения ONNX-модели с float16

Для конвертации моделей в float16 воспользуемся методом, описанным в Create Float16 and Mixed Precision Models.

# Copyright 2024, MetaQuotes Ltd.
# https://www.mql5.com

import onnx
from onnxconverter_common import float16

from sys import argv

# Define the path for saving the model
data_path = argv[0]
last_index = data_path.rfind("\\") + 1
data_path = data_path[0:last_index]

# конвертация модели в float16
model_path = data_path+'\\models\\esrgan.onnx'
modelfp16_path = data_path+'\\models\\esrgan_float16.onnx'

model = onnx.load(model_path)
model_fp16 = float16.convert_float_to_float16(model)
onnx.save(model_fp16, modelfp16_path)

После конвертации размер файла уменьшился вдвое (с 64Mb до 32Mb).

Изменения в коде минимальные:

//+------------------------------------------------------------------+
//|                                               ESRGAN_float16.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
//+------------------------------------------------------------------+
//| 4x image upscaling demo using ESRGAN                             |
//| esrgan.onnx model from                                           |
//| https://github.com/amannm/super-resolution-service/              |
//+------------------------------------------------------------------+
//| Xintao Wang et al (2018)                                         |
//| ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks|
//| https://arxiv.org/abs/1809.00219                                 |
//+------------------------------------------------------------------+
#resource "models\\esrgan_float16.onnx" as uchar ExtModel[];
#include <Canvas\Canvas.mqh>
//+------------------------------------------------------------------+
//| clamp                                                            |
//+------------------------------------------------------------------+
float clamp(float value, float minValue, float maxValue)
  {
   return MathMin(MathMax(value, minValue), maxValue);
  }
//+------------------------------------------------------------------+
//| Preprocessing                                                    |
//+------------------------------------------------------------------+
bool Preprocessing(float &data[],uint &image_data[],int &image_width,int &image_height)
  {
//--- checkup
   if(image_height==0 || image_width==0)
      return(false);
//--- prepare destination array with separated RGB channels for ONNX model
   int data_count=3*image_width*image_height;

   if(ArrayResize(data,data_count)!=data_count)
     {
      Print("ArrayResize failed");
      return(false);
     }
//--- converting
   for(int y=0; y<image_height; y++)
      for(int x=0; x<image_width; x++)
        {
         //--- load source RGB
         int   offset=y*image_width+x;
         uint  clr   =image_data[offset];
         uchar r     =GETRGBR(clr);
         uchar g     =GETRGBG(clr);
         uchar b     =GETRGBB(clr);
         //--- store RGB components as separated channels
         int offset_ch1=0*image_width*image_height+offset;
         int offset_ch2=1*image_width*image_height+offset;
         int offset_ch3=2*image_width*image_height+offset;

         data[offset_ch1]=r/255.0f;
         data[offset_ch2]=g/255.0f;
         data[offset_ch3]=b/255.0f;
        }
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| PostProcessing                                                   |
//+------------------------------------------------------------------+
bool PostProcessing(const float &data[], uint &image_data[], const int &image_width, const int &image_height)
  {
//--- checks
   if(image_height == 0 || image_width == 0)
      return(false);

   int data_count=image_width*image_height;
   
   if(ArraySize(data)!=3*data_count)
      return(false);
   if(ArrayResize(image_data,data_count)!=data_count)
      return(false);
//---
   for(int y=0; y<image_height; y++)
      for(int x=0; x<image_width; x++)
        {
         int offset    =y*image_width+x;
         int offset_ch1=0*image_width*image_height+offset;
         int offset_ch2=1*image_width*image_height+offset;
         int offset_ch3=2*image_width*image_height+offset;
         //--- rescale to [0..255]
         float r=clamp(data[offset_ch1]*255,0,255);
         float g=clamp(data[offset_ch2]*255,0,255);
         float b=clamp(data[offset_ch3]*255,0,255);
         //--- set color image_data
         image_data[offset]=XRGB(uchar(r),uchar(g),uchar(b));
        }
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| ShowImage                                                        |
//+------------------------------------------------------------------+
bool ShowImage(CCanvas &canvas,const string name,const int x0,const int y0,const int image_width,const int image_height, const uint &image_data[])
  {
   if(ArraySize(image_data)==0 || name=="")
      return(false);
//--- prepare canvas
   canvas.CreateBitmapLabel(name,x0,y0,image_width,image_height,COLOR_FORMAT_XRGB_NOALPHA);
//--- copy image to canvas
   for(int y=0; y<image_height; y++)
      for(int x=0; x<image_width; x++)
         canvas.PixelSet(x,y,image_data[y*image_width+x]);
//--- ready to draw
   canvas.Update(true);
   return(true);
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
int OnStart(void)
  {
//--- select BMP from <data folder>\MQL5\Files
   string image_path[1];

   if(FileSelectDialog("Select BMP image",NULL,"Bitmap files (*.bmp)|*.bmp",FSD_FILE_MUST_EXIST,image_path,"lenna.bmp")!=1)
     {
      Print("file not selected");
      return(-1);
     }
//--- load BMP into array
   uint image_data[];
   int  image_width;
   int  image_height;

   if(!CCanvas::LoadBitmap(image_path[0],image_data,image_width,image_height))
     {
      PrintFormat("CCanvas::LoadBitmap failed with error %d",GetLastError());
      return(-1);
     }
//--- convert RGB image to separated RGB channels
   float input_data[];
   Preprocessing(input_data,image_data,image_width,image_height);
   PrintFormat("input array size=%d",ArraySize(input_data));
   
   ushort input_data_float16[];
   if(!ArrayToFP16(input_data_float16,input_data,FLOAT_FP16))
     {
      Print("error in ArrayToFP16. error code=",GetLastError());
      return(false);
     }   
//--- load model
   long model_handle=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT);
   if(model_handle==INVALID_HANDLE)
     {
      PrintFormat("OnnxCreate error %d",GetLastError());
      return(-1);
     }

   PrintFormat("model loaded successfully");
   PrintFormat("original:  width=%d, height=%d  Size=%d",image_width,image_height,ArraySize(image_data));
//--- set input shape
   ulong input_shape[]={1,3,image_height,image_width};

   if(!OnnxSetInputShape(model_handle,0,input_shape))
     {
      PrintFormat("error in OnnxSetInputShape. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }
//--- upscaled image size
   int   new_image_width =4*image_width;
   int   new_image_height=4*image_height;
   ulong output_shape[]= {1,3,new_image_height,new_image_width};

   if(!OnnxSetOutputShape(model_handle,0,output_shape))
     {
      PrintFormat("error in OnnxSetOutputShape. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }
//--- run the model
   float output_data[];
   ushort output_data_float16[];
   int new_data_count=3*new_image_width*new_image_height;
   if(ArrayResize(output_data_float16,new_data_count)!=new_data_count)
     {
      OnnxRelease(model_handle);
      return(-1);
     }

   if(!OnnxRun(model_handle,ONNX_NO_CONVERSION,input_data_float16,output_data_float16))
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }

   Print("model successfully executed, output data size ",ArraySize(output_data));
   OnnxRelease(model_handle);
   
   if(!ArrayFromFP16(output_data,output_data_float16,FLOAT_FP16))
     {
      Print("error in ArrayFromFP16. error code=",GetLastError());
      return(false);
     }   
//--- postprocessing
   uint new_image[];
   PostProcessing(output_data,new_image,new_image_width,new_image_height);
//--- show images
   CCanvas canvas_original,canvas_scaled;
   ShowImage(canvas_original,"original_image",new_image_width,0,image_width,image_height,image_data);
   ShowImage(canvas_scaled,"upscaled_image",0,0,new_image_width,new_image_height,new_image);
//--- save upscaled image
   StringReplace(image_path[0],".bmp","_upscaled.bmp");
   Print(ResourceSave(canvas_scaled.ResourceName(),image_path[0]));
//---
   while(!IsStopped())
      Sleep(100);

   return(0);
  }
//+------------------------------------------------------------------+

Изменения в коде, которые потребовались для исполнения модели, сконвертированной в формат float16, выделены цветом.

Результат:

 Рис. Результат работы модели ESRGAN_float16.onnx (160x200 -> 640x800)

 Рис.15. Результат работы модели ESRGAN_float16.onnx (160x200 -> 640x800)


Таким образом, использование чисел float16 вместо float32 позволяет сократить размер файла ONNX-модели в 2 раза (с 64Mb до 32Mb).

При исполнении моделей на числах float16 качество изображений осталось прежним, визуально трудно найти отличия:

 Рис.16. Сравнение результатов работы модели ESRGAN для float и float16

Рис.16. Сравнение результатов работы модели ESRGAN для float и float16


Изменения в коде минимальные, нужно лишь позаботиться о конвертации входных и выходных данных.

В данном случае после конвертации в float16 качество работы модели существенно не изменилось, однако при анализе финансовых данных следует стремиться к расчетам с максимально возможной точностью.


Выводы

Использование новых типов данных для чисел с плавающей точкой позволяет сократить размер ONNX-моделей без существенной потери качества.

Препроцессинг и постпроцессинг данных значительно упрощаются благодаря использованию функций конвертации данных ArrayToFP16/ArrayFromFP16 и ArrayToFP8/ArrayFromFP8.

Для работы с конвертированными ONNX-моделями требуются минимальные изменения в коде.