Trabalho com modelos ONNX nos formatos float16 e float8
Conteúdo
- 1. Novos tipos de dados para trabalhar com modelos ONNX
- 1.1. Formato FP16
- 1.1.1. Testes de execução do operador ONNX Cast para FLOAT16
- 1.1.2. Testes de execução do operador ONNX Cast para BFLOAT16
- 1.2. Formato FP8
- 1.2.1. Formatos fp8_e5m2 e fp8_e4m3
- 1.2.2. Testes de execução do operador ONNX Cast para FLOAT8
- 2. Exemplo de uso do ONNX para aumentar a resolução de imagens
- 2.1. Exemplo de execução do modelo ONNX com float32
- 2.2. Exemplo de execução do modelo ONNX com float16
- Conclusões
Com o desenvolvimento das tecnologias de aprendizado de máquina e inteligência artificial, surge a necessidade de otimizar os processos de trabalho com modelos. A eficiência dos modelos depende diretamente dos formatos de dados utilizados para representá-los. Nos últimos anos, surgiram vários novos tipos de dados projetados especificamente para trabalhar com modelos de aprendizado profundo.
Neste artigo, vamos focar em dois desses novos formatos de dados, especificamente float16 e float8, que estão começando a ser amplamente utilizados nos modelos ONNX modernos. Esses formatos representam alternativas aos formatos de dados de ponto flutuante mais precisos, mas que demandam mais recursos. Eles oferecem uma combinação ótima de desempenho e precisão, tornando-os especialmente atraentes para realizar várias tarefas de aprendizado de máquina. Vamos explorar as principais características e vantagens dos formatos float16 e float8, bem como apresentar funções para sua conversão em float e double padrão.
Isso ajudará desenvolvedores e pesquisadores a entender melhor como usar esses formatos de forma eficiente em seus projetos e modelos. Como exemplo, vamos examinar o funcionamento do modelo ONNX ESRGAN, que é utilizado para melhorar a qualidade das imagens.
1. Novos tipos de dados para trabalhar com modelos ONNX
Para acelerar os cálculos, alguns modelos utilizam tipos de dados com menor precisão, como Float16 e até mesmo Float8.
Para trabalhar com modelos ONNX na linguagem MQL5, foi adicionada a suporte para novos tipos de dados, permitindo trabalhar com representações de números de ponto flutuante de 8 e 16 bits.
O script exibe a lista completa dos elementos da enumeração ENUM_ONNX_DATA_TYPE.
//+------------------------------------------------------------------+ //| ONNX_Data_Types.mq5 | //| Copyright 2024, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2024, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" //+------------------------------------------------------------------+ //| Script program start function | //+------------------------------------------------------------------+ void OnStart() { //--- for(int i=0; i<21; i++) PrintFormat("%2d %s",i,EnumToString(ENUM_ONNX_DATA_TYPE(i))); }
Resultado:
0: ONNX_DATA_TYPE_UNDEFINED 1: ONNX_DATA_TYPE_FLOAT 2: ONNX_DATA_TYPE_UINT8 3: ONNX_DATA_TYPE_INT8 4: ONNX_DATA_TYPE_UINT16 5: ONNX_DATA_TYPE_INT16 6: ONNX_DATA_TYPE_INT32 7: ONNX_DATA_TYPE_INT64 8: ONNX_DATA_TYPE_STRING 9: ONNX_DATA_TYPE_BOOL 10: ONNX_DATA_TYPE_FLOAT16 11: ONNX_DATA_TYPE_DOUBLE 12: ONNX_DATA_TYPE_UINT32 13: ONNX_DATA_TYPE_UINT64 14: ONNX_DATA_TYPE_COMPLEX64 15: ONNX_DATA_TYPE_COMPLEX128 16: ONNX_DATA_TYPE_BFLOAT16 17: ONNX_DATA_TYPE_FLOAT8E4M3FN 18: ONNX_DATA_TYPE_FLOAT8E4M3FNUZ 19: ONNX_DATA_TYPE_FLOAT8E5M2 20: ONNX_DATA_TYPE_FLOAT8E5M2FNUZ
Assim, agora é possível executar modelos ONNX que trabalham com esses dados.
Além disso, no MQL5, surgiram funções adicionais para conversão de dados:
bool ArrayToFP16(ushort &dst_array[],const float &src_array[],ENUM_FLOAT16_FORMAT fmt); bool ArrayToFP16(ushort &dst_array[],const double &src_array[],ENUM_FLOAT16_FORMAT fmt); bool ArrayToFP8(uchar &dst_array[],const float &src_array[],ENUM_FLOAT8_FORMAT fmt); bool ArrayToFP8(uchar &dst_array[],const double &src_array[],ENUM_FLOAT8_FORMAT fmt); bool ArrayFromFP16(float &dst_array[],const ushort &src_array[],ENUM_FLOAT16_FORMAT fmt); bool ArrayFromFP16(double &dst_array[],const ushort &src_array[],ENUM_FLOAT16_FORMAT fmt); bool ArrayFromFP8(float &dst_array[],const uchar &src_array[],ENUM_FLOAT8_FORMAT fmt); bool ArrayFromFP8(double &dst_array[],const uchar &src_array[],ENUM_FLOAT8_FORMAT fmt);
Como os formatos de números reais de 16 e 8 bits podem variar, no parâmetro fmt das funções de conversão, é necessário especificar qual formato de número deve ser processado.
Para as versões de 16 bits, é utilizada uma nova enumeração ENUM_FLOAT16_FORMAT, que atualmente possui os seguintes valores:
- FLOAT_FP16 — formato padrão de 16 bits, também conhecido como half.
- FLOAT_BFP16 — formato especial brain float point.
- FLOAT_FP8_E4M3FN — número de ponto flutuante de 8 bits, com 4 bits para o expoente e 3 bits para a mantissa. Geralmente usado como coeficientes.
- FLOAT_FP8_E4M3FNUZ — número de ponto flutuante de 8 bits, com 4 bits para o expoente e 3 bits para a mantissa. Suporta NaN, não suporta zero negativo e Inf. Geralmente usado como coeficientes.
- FLOAT_FP8_E5M2FN — número de ponto flutuante de 8 bits, com 5 bits para o expoente e 2 bits para a mantissa. Suporta NaN e Inf. Geralmente usado para gradientes.
- FLOAT_FP8_E5M2FNUZ — número de ponto flutuante de 8 bits, com 5 bits para o expoente e 2 bits para a mantissa. Suporta NaN e Inf, não suporta zero negativo. Também usado para gradientes.
1.1. Formato FP16
Os formatos FLOAT16 e BFLOAT16 são tipos de dados usados para representar números de ponto flutuante.
FLOAT16, também conhecido como meia precisão ou formato "half-precision float", utiliza 16 bits para representar um número de ponto flutuante. Este formato oferece um equilíbrio entre precisão e eficiência computacional. FLOAT16 é amplamente utilizado em aprendizado profundo e redes neurais, onde é necessária alta performance no processamento de grandes volumes de dados. Este formato permite acelerar os cálculos ao reduzir o tamanho dos números, o que é especialmente importante ao treinar redes neurais profundas em unidades de processamento gráfico (GPU).
BFLOAT16 (ou Brain Floating Point 16) também utiliza 16 bits, mas difere do FLOAT16 na maneira como os números são representados. Neste formato, 8 bits são alocados para representar o expoente, e os 7 bits restantes são usados para representar a mantissa. Este formato foi desenvolvido para uso em aprendizado profundo e inteligência artificial, especialmente nos processadores Google Tensor Processing Unit (TPU). BFLOAT16 possui boa performance no treinamento de redes neurais e pode ser utilizado de maneira eficiente para acelerar cálculos.
Ambos os formatos têm suas vantagens e limitações. FLOAT16 proporciona maior precisão, mas requer mais recursos para armazenamento e cálculos. BFLOAT16, por outro lado, oferece maior performance e eficiência no processamento de dados, mas pode ser menos preciso.
Fig.1. Formatos de representação de bits de números de ponto flutuante FLOAT16 e BFLOAT16
Tab.1. Números de ponto flutuante no formato FLOAT16
1.1.1. Testes de execução do operador ONNX Cast para FLOAT16
Como ilustração, considere a tarefa de conversão de dados do tipo FLOAT16 para os tipos float e double.
Modelos ONNX com operação Cast:
- https://github.com/onnx/onnx/tree/main/onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT
- https://github.com/onnx/onnx/tree/main/onnx/backend/test/data/node/test_cast_FLOAT16_to_DOUBLE
Fig.2. Propriedades de entrada e saída dos modelos test_cast_FLOAT16_to_DOUBLE.onnx
Fig.3. Propriedades de entrada e saída do modelo test_cast_FLOAT16_to_FLOAT.onnx
Como indicado nas descrições das propriedades dos modelos ONNX, os dados de entrada devem ser do tipo ONNX_DATA_TYPE_FLOAT16, e os dados de saída serão retornados no formato ONNX_DATA_TYPE_FLOAT.
Para a conversão dos valores, utilizaremos as funções ArrayToFP16() e ArrayFromFP16() com o parâmetro FLOAT_FP16.
Exemplo:
//+------------------------------------------------------------------+ //| TestCastFloat16.mq5 | //| Copyright 2024, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2024, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" #resource "models\\test_cast_FLOAT16_to_DOUBLE.onnx" as const uchar ExtModel1[]; #resource "models\\test_cast_FLOAT16_to_FLOAT.onnx" as const uchar ExtModel2[]; //+------------------------------------------------------------------+ //| union for data conversion | //+------------------------------------------------------------------+ template<typename T> union U { uchar uc[sizeof(T)]; T value; }; //+------------------------------------------------------------------+ //| ArrayToString | //+------------------------------------------------------------------+ template<typename T> string ArrayToString(const T &data[],uint length=16) { string res; for(uint n=0; n<MathMin(length,data.Size()); n++) res+="," + StringFormat("%.2x",data[n]); StringSetCharacter(res,0,'['); return res+"]"; } //+------------------------------------------------------------------+ //| PatchONNXModel | //+------------------------------------------------------------------+ void PatchONNXModel(const uchar &original_model[],uchar &patched_model[]) { ArrayCopy(patched_model,original_model,0,0,WHOLE_ARRAY); //--- special ONNX model patch(IR=9,Opset=20) patched_model[1]=0x09; patched_model[ArraySize(patched_model)-1]=0x14; } //+------------------------------------------------------------------+ //| CreateModel | //+------------------------------------------------------------------+ bool CreateModel(long &model_handle,const uchar &model[]) { model_handle=INVALID_HANDLE; ulong flags=ONNX_DEFAULT; //ulong flags=ONNX_DEBUG_LOGS; //--- model_handle=OnnxCreateFromBuffer(model,flags); if(model_handle==INVALID_HANDLE) return(false); //--- return(true); } //+------------------------------------------------------------------+ //| PrepareShapes | //+------------------------------------------------------------------+ bool PrepareShapes(long model_handle) { ulong input_shape1[]= {3,4}; if(!OnnxSetInputShape(model_handle,0,input_shape1)) { PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError()); //-- OnnxRelease(model_handle); return(false); } //--- ulong output_shape[]= {3,4}; if(!OnnxSetOutputShape(model_handle,0,output_shape)) { PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError()); //-- OnnxRelease(model_handle); return(false); } //--- return(true); } //+------------------------------------------------------------------+ //| RunCastFloat16ToDouble | //+------------------------------------------------------------------+ bool RunCastFloat16ToDouble(long model_handle) { PrintFormat("test=%s",__FUNCTION__); double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12}; ushort data_uint16[12]; if(!ArrayToFP16(data_uint16,test_data,FLOAT_FP16)) { Print("error in ArrayToFP16. error code=",GetLastError()); return(false); } Print("test array:"); ArrayPrint(test_data); Print("ArrayToFP16:"); ArrayPrint(data_uint16); U<ushort> input_float16_values[3*4]; U<double> output_double_values[3*4]; float test_data_float[]; if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_FP16)) { Print("error in ArrayFromFP16. error code=",GetLastError()); return(false); } for(int i=0; i<12; i++) { input_float16_values[i].value=data_uint16[i]; PrintFormat("%d input value =%f Hex float16 = %s ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value); } Print("ONNX input array:"); ArrayPrint(input_float16_values); bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_double_values); if(!res) { PrintFormat("error in OnnxRun. error code=%d",GetLastError()); return(false); } Print("ONNX output array:"); ArrayPrint(output_double_values); //--- double sum_error=0.0; for(int i=0; i<12; i++) { double delta=test_data[i]-output_double_values[i].value; sum_error+=MathAbs(delta); PrintFormat("%d output double %f = %s difference=%f",i,output_double_values[i].value,ArrayToString(output_double_values[i].uc),delta); } //--- PrintFormat("test=%s sum_error=%f",__FUNCTION__,sum_error); //--- return(true); } //+------------------------------------------------------------------+ //| RunCastFloat16ToFloat | //+------------------------------------------------------------------+ bool RunCastFloat16ToFloat(long model_handle) { PrintFormat("test=%s",__FUNCTION__); double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12}; ushort data_uint16[12]; if(!ArrayToFP16(data_uint16,test_data,FLOAT_FP16)) { Print("error in ArrayToFP16. error code=",GetLastError()); return(false); } Print("test array:"); ArrayPrint(test_data); Print("ArrayToFP16:"); ArrayPrint(data_uint16); U<ushort> input_float16_values[3*4]; U<float> output_float_values[3*4]; float test_data_float[]; if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_FP16)) { Print("error in ArrayFromFP16. error code=",GetLastError()); return(false); } for(int i=0; i<12; i++) { input_float16_values[i].value=data_uint16[i]; PrintFormat("%d input value =%f Hex float16 = %s ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value); } Print("ONNX input array:"); ArrayPrint(input_float16_values); bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_float_values); if(!res) { PrintFormat("error in OnnxRun. error code=%d",GetLastError()); return(false); } Print("ONNX output array:"); ArrayPrint(output_float_values); //--- double sum_error=0.0; for(int i=0; i<12; i++) { double delta=test_data[i]-(double)output_float_values[i].value; sum_error+=MathAbs(delta); PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToString(output_float_values[i].uc),delta); } //--- PrintFormat("test=%s sum_error=%f",__FUNCTION__,sum_error); //--- return(true); } //+------------------------------------------------------------------+ //| TestCastFloat16ToFloat | //+------------------------------------------------------------------+ bool TestCastFloat16ToFloat(const uchar &res_model[]) { uchar model[]; PatchONNXModel(res_model,model); //--- get model handle long model_handle=INVALID_HANDLE; //--- get model handle if(!CreateModel(model_handle,model)) return(false); //--- prepare input and output shapes if(!PrepareShapes(model_handle)) return(false); //--- run ONNX model if(!RunCastFloat16ToFloat(model_handle)) return(false); //--- release model handle OnnxRelease(model_handle); //--- return(true); } //+------------------------------------------------------------------+ //| TestCastFloat16ToDouble | //+------------------------------------------------------------------+ bool TestCastFloat16ToDouble(const uchar &res_model[]) { uchar model[]; PatchONNXModel(res_model,model); //--- long model_handle=INVALID_HANDLE; //--- get model handle if(!CreateModel(model_handle,model)) return(false); //--- prepare input and output shapes if(!PrepareShapes(model_handle)) return(false); //--- run ONNX model if(!RunCastFloat16ToDouble(model_handle)) return(false); //--- release model handle OnnxRelease(model_handle); //--- return(true); } //+------------------------------------------------------------------+ //| Script program start function | //+------------------------------------------------------------------+ int OnStart(void) { if(!TestCastFloat16ToDouble(ExtModel1)) return 1; if(!TestCastFloat16ToFloat(ExtModel2)) return 1; //--- return 0; } //+------------------------------------------------------------------+
Resultado:
TestCastFloat16 (EURUSD,H1) test=RunCastFloat16ToDouble TestCastFloat16 (EURUSD,H1) test array: TestCastFloat16 (EURUSD,H1) 1.00000 2.00000 3.00000 4.00000 5.00000 6.00000 7.00000 8.00000 9.00000 10.00000 11.00000 12.00000 TestCastFloat16 (EURUSD,H1) ArrayToFP16: TestCastFloat16 (EURUSD,H1) 15360 16384 16896 17408 17664 17920 18176 18432 18560 18688 18816 18944 TestCastFloat16 (EURUSD,H1) 0 input value =1.000000 Hex float16 = [00,3c] ushort value=15360 TestCastFloat16 (EURUSD,H1) 1 input value =2.000000 Hex float16 = [00,40] ushort value=16384 TestCastFloat16 (EURUSD,H1) 2 input value =3.000000 Hex float16 = [00,42] ushort value=16896 TestCastFloat16 (EURUSD,H1) 3 input value =4.000000 Hex float16 = [00,44] ushort value=17408 TestCastFloat16 (EURUSD,H1) 4 input value =5.000000 Hex float16 = [00,45] ushort value=17664 TestCastFloat16 (EURUSD,H1) 5 input value =6.000000 Hex float16 = [00,46] ushort value=17920 TestCastFloat16 (EURUSD,H1) 6 input value =7.000000 Hex float16 = [00,47] ushort value=18176 TestCastFloat16 (EURUSD,H1) 7 input value =8.000000 Hex float16 = [00,48] ushort value=18432 TestCastFloat16 (EURUSD,H1) 8 input value =9.000000 Hex float16 = [80,48] ushort value=18560 TestCastFloat16 (EURUSD,H1) 9 input value =10.000000 Hex float16 = [00,49] ushort value=18688 TestCastFloat16 (EURUSD,H1) 10 input value =11.000000 Hex float16 = [80,49] ushort value=18816 TestCastFloat16 (EURUSD,H1) 11 input value =12.000000 Hex float16 = [00,4a] ushort value=18944 TestCastFloat16 (EURUSD,H1) ONNX input array: TestCastFloat16 (EURUSD,H1) [uc] [value] TestCastFloat16 (EURUSD,H1) [ 0] ... 15360 TestCastFloat16 (EURUSD,H1) [ 1] ... 16384 TestCastFloat16 (EURUSD,H1) [ 2] ... 16896 TestCastFloat16 (EURUSD,H1) [ 3] ... 17408 TestCastFloat16 (EURUSD,H1) [ 4] ... 17664 TestCastFloat16 (EURUSD,H1) [ 5] ... 17920 TestCastFloat16 (EURUSD,H1) [ 6] ... 18176 TestCastFloat16 (EURUSD,H1) [ 7] ... 18432 TestCastFloat16 (EURUSD,H1) [ 8] ... 18560 TestCastFloat16 (EURUSD,H1) [ 9] ... 18688 TestCastFloat16 (EURUSD,H1) [10] ... 18816 TestCastFloat16 (EURUSD,H1) [11] ... 18944 TestCastFloat16 (EURUSD,H1) ONNX output array: TestCastFloat16 (EURUSD,H1) [uc] [value] TestCastFloat16 (EURUSD,H1) [ 0] ... 1.00000 TestCastFloat16 (EURUSD,H1) [ 1] ... 2.00000 TestCastFloat16 (EURUSD,H1) [ 2] ... 3.00000 TestCastFloat16 (EURUSD,H1) [ 3] ... 4.00000 TestCastFloat16 (EURUSD,H1) [ 4] ... 5.00000 TestCastFloat16 (EURUSD,H1) [ 5] ... 6.00000 TestCastFloat16 (EURUSD,H1) [ 6] ... 7.00000 TestCastFloat16 (EURUSD,H1) [ 7] ... 8.00000 TestCastFloat16 (EURUSD,H1) [ 8] ... 9.00000 TestCastFloat16 (EURUSD,H1) [ 9] ... 10.00000 TestCastFloat16 (EURUSD,H1) [10] ... 11.00000 TestCastFloat16 (EURUSD,H1) [11] ... 12.00000 TestCastFloat16 (EURUSD,H1) 0 output double 1.000000 = [00,00,00,00,00,00,f0,3f] difference=0.000000 TestCastFloat16 (EURUSD,H1) 1 output double 2.000000 = [00,00,00,00,00,00,00,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 2 output double 3.000000 = [00,00,00,00,00,00,08,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 3 output double 4.000000 = [00,00,00,00,00,00,10,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 4 output double 5.000000 = [00,00,00,00,00,00,14,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 5 output double 6.000000 = [00,00,00,00,00,00,18,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 6 output double 7.000000 = [00,00,00,00,00,00,1c,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 7 output double 8.000000 = [00,00,00,00,00,00,20,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 8 output double 9.000000 = [00,00,00,00,00,00,22,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 9 output double 10.000000 = [00,00,00,00,00,00,24,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 10 output double 11.000000 = [00,00,00,00,00,00,26,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 11 output double 12.000000 = [00,00,00,00,00,00,28,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) test=RunCastFloat16ToDouble sum_error=0.000000 TestCastFloat16 (EURUSD,H1) test=RunCastFloat16ToFloat TestCastFloat16 (EURUSD,H1) test array: TestCastFloat16 (EURUSD,H1) 1.00000 2.00000 3.00000 4.00000 5.00000 6.00000 7.00000 8.00000 9.00000 10.00000 11.00000 12.00000 TestCastFloat16 (EURUSD,H1) ArrayToFP16: TestCastFloat16 (EURUSD,H1) 15360 16384 16896 17408 17664 17920 18176 18432 18560 18688 18816 18944 TestCastFloat16 (EURUSD,H1) 0 input value =1.000000 Hex float16 = [00,3c] ushort value=15360 TestCastFloat16 (EURUSD,H1) 1 input value =2.000000 Hex float16 = [00,40] ushort value=16384 TestCastFloat16 (EURUSD,H1) 2 input value =3.000000 Hex float16 = [00,42] ushort value=16896 TestCastFloat16 (EURUSD,H1) 3 input value =4.000000 Hex float16 = [00,44] ushort value=17408 TestCastFloat16 (EURUSD,H1) 4 input value =5.000000 Hex float16 = [00,45] ushort value=17664 TestCastFloat16 (EURUSD,H1) 5 input value =6.000000 Hex float16 = [00,46] ushort value=17920 TestCastFloat16 (EURUSD,H1) 6 input value =7.000000 Hex float16 = [00,47] ushort value=18176 TestCastFloat16 (EURUSD,H1) 7 input value =8.000000 Hex float16 = [00,48] ushort value=18432 TestCastFloat16 (EURUSD,H1) 8 input value =9.000000 Hex float16 = [80,48] ushort value=18560 TestCastFloat16 (EURUSD,H1) 9 input value =10.000000 Hex float16 = [00,49] ushort value=18688 TestCastFloat16 (EURUSD,H1) 10 input value =11.000000 Hex float16 = [80,49] ushort value=18816 TestCastFloat16 (EURUSD,H1) 11 input value =12.000000 Hex float16 = [00,4a] ushort value=18944 TestCastFloat16 (EURUSD,H1) ONNX input array: TestCastFloat16 (EURUSD,H1) [uc] [value] TestCastFloat16 (EURUSD,H1) [ 0] ... 15360 TestCastFloat16 (EURUSD,H1) [ 1] ... 16384 TestCastFloat16 (EURUSD,H1) [ 2] ... 16896 TestCastFloat16 (EURUSD,H1) [ 3] ... 17408 TestCastFloat16 (EURUSD,H1) [ 4] ... 17664 TestCastFloat16 (EURUSD,H1) [ 5] ... 17920 TestCastFloat16 (EURUSD,H1) [ 6] ... 18176 TestCastFloat16 (EURUSD,H1) [ 7] ... 18432 TestCastFloat16 (EURUSD,H1) [ 8] ... 18560 TestCastFloat16 (EURUSD,H1) [ 9] ... 18688 TestCastFloat16 (EURUSD,H1) [10] ... 18816 TestCastFloat16 (EURUSD,H1) [11] ... 18944 TestCastFloat16 (EURUSD,H1) ONNX output array: TestCastFloat16 (EURUSD,H1) [uc] [value] TestCastFloat16 (EURUSD,H1) [ 0] ... 1.00000 TestCastFloat16 (EURUSD,H1) [ 1] ... 2.00000 TestCastFloat16 (EURUSD,H1) [ 2] ... 3.00000 TestCastFloat16 (EURUSD,H1) [ 3] ... 4.00000 TestCastFloat16 (EURUSD,H1) [ 4] ... 5.00000 TestCastFloat16 (EURUSD,H1) [ 5] ... 6.00000 TestCastFloat16 (EURUSD,H1) [ 6] ... 7.00000 TestCastFloat16 (EURUSD,H1) [ 7] ... 8.00000 TestCastFloat16 (EURUSD,H1) [ 8] ... 9.00000 TestCastFloat16 (EURUSD,H1) [ 9] ... 10.00000 TestCastFloat16 (EURUSD,H1) [10] ... 11.00000 TestCastFloat16 (EURUSD,H1) [11] ... 12.00000 TestCastFloat16 (EURUSD,H1) 0 output float 1.000000 = [00,00,80,3f] difference=0.000000 TestCastFloat16 (EURUSD,H1) 1 output float 2.000000 = [00,00,00,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 2 output float 3.000000 = [00,00,40,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 3 output float 4.000000 = [00,00,80,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 4 output float 5.000000 = [00,00,a0,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 5 output float 6.000000 = [00,00,c0,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 6 output float 7.000000 = [00,00,e0,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 7 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat16 (EURUSD,H1) 8 output float 9.000000 = [00,00,10,41] difference=0.000000 TestCastFloat16 (EURUSD,H1) 9 output float 10.000000 = [00,00,20,41] difference=0.000000 TestCastFloat16 (EURUSD,H1) 10 output float 11.000000 = [00,00,30,41] difference=0.000000 TestCastFloat16 (EURUSD,H1) 11 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat16 (EURUSD,H1) test=RunCastFloat16ToFloat sum_error=0.000000
1.1.2. Testes de execução do operador ONNX Cast para BFLOAT16
Neste exemplo, considera-se a conversão do tipo BFLOAT16 para float.
Modelo ONNX com operação Cast:
Fig.4. Propriedades de entrada e saída do modelo test_cast_BFLOAT16_to_FLOAT.onnx
Os dados de entrada devem ser do tipo ONNX_DATA_TYPE_BFLOAT16, e os dados de saída serão retornados no formato ONNX_DATA_TYPE_FLOAT.
Para a conversão dos valores, utilizaremos as funções ArrayToFP16() e ArrayFromFP16() com o parâmetro BFLOAT_FP16.
//+------------------------------------------------------------------+ //| TestCastBFloat16.mq5 | //| Copyright 2024, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2024, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" #resource "models\\test_cast_BFLOAT16_to_FLOAT.onnx" as const uchar ExtModel1[]; //+------------------------------------------------------------------+ //| union for data conversion | //+------------------------------------------------------------------+ template<typename T> union U { uchar uc[sizeof(T)]; T value; }; //+------------------------------------------------------------------+ //| ArrayToString | //+------------------------------------------------------------------+ template<typename T> string ArrayToString(const T &data[],uint length=16) { string res; for(uint n=0; n<MathMin(length,data.Size()); n++) res+="," + StringFormat("%.2x",data[n]); StringSetCharacter(res,0,'['); return res+"]"; } //+------------------------------------------------------------------+ //| PatchONNXModel | //+------------------------------------------------------------------+ void PatchONNXModel(const uchar &original_model[],uchar &patched_model[]) { ArrayCopy(patched_model,original_model,0,0,WHOLE_ARRAY); //--- special ONNX model patch(IR=9,Opset=20) patched_model[1]=0x09; patched_model[ArraySize(patched_model)-1]=0x14; } //+------------------------------------------------------------------+ //| CreateModel | //+------------------------------------------------------------------+ bool CreateModel(long &model_handle,const uchar &model[]) { model_handle=INVALID_HANDLE; ulong flags=ONNX_DEFAULT; //ulong flags=ONNX_DEBUG_LOGS; //--- model_handle=OnnxCreateFromBuffer(model,flags); if(model_handle==INVALID_HANDLE) return(false); //--- return(true); } //+------------------------------------------------------------------+ //| PrepareShapes | //+------------------------------------------------------------------+ bool PrepareShapes(long model_handle) { ulong input_shape1[]= {3,4}; if(!OnnxSetInputShape(model_handle,0,input_shape1)) { PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError()); //-- OnnxRelease(model_handle); return(false); } //--- ulong output_shape[]= {3,4}; if(!OnnxSetOutputShape(model_handle,0,output_shape)) { PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError()); //-- OnnxRelease(model_handle); return(false); } //--- return(true); } //+------------------------------------------------------------------+ //| RunCastBFloat16ToFloat | //+------------------------------------------------------------------+ bool RunCastBFloat16ToFloat(long model_handle) { PrintFormat("test=%s",__FUNCTION__); double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12}; ushort data_uint16[12]; if(!ArrayToFP16(data_uint16,test_data,FLOAT_BFP16)) { Print("error in ArrayToFP16. error code=",GetLastError()); return(false); } Print("test array:"); ArrayPrint(test_data); Print("ArrayToFP16:"); ArrayPrint(data_uint16); U<ushort> input_float16_values[3*4]; U<float> output_float_values[3*4]; float test_data_float[]; if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_BFP16)) { Print("error in ArrayFromFP16. error code=",GetLastError()); return(false); } for(int i=0; i<12; i++) { input_float16_values[i].value=data_uint16[i]; PrintFormat("%d input value =%f Hex float16 = %s ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value); } Print("ONNX input array:"); ArrayPrint(input_float16_values); bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_float_values); if(!res) { PrintFormat("error in OnnxRun. error code=%d",GetLastError()); return(false); } Print("ONNX output array:"); ArrayPrint(output_float_values); //--- double sum_error=0.0; for(int i=0; i<12; i++) { double delta=test_data[i]-(double)output_float_values[i].value; sum_error+=MathAbs(delta); PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToString(output_float_values[i].uc),delta); } //--- PrintFormat("test=%s sum_error=%f",__FUNCTION__,sum_error); //--- return(true); } //+------------------------------------------------------------------+ //| Script program start function | //+------------------------------------------------------------------+ int OnStart(void) { uchar model[]; PatchONNXModel(ExtModel1,model); //--- get model handle long model_handle=INVALID_HANDLE; //--- get model handle if(!CreateModel(model_handle,model)) return 1; //--- prepare input and output shapes if(!PrepareShapes(model_handle)) return 1; //--- run ONNX model if(!RunCastBFloat16ToFloat(model_handle)) return 1; //--- release model handle OnnxRelease(model_handle); //--- return 0; } //+------------------------------------------------------------------+Resultado:
TestCastBFloat16 (EURUSD,H1) test=RunCastBFloat16ToFloat TestCastBFloat16 (EURUSD,H1) test array: TestCastBFloat16 (EURUSD,H1) 1.00000 2.00000 3.00000 4.00000 5.00000 6.00000 7.00000 8.00000 9.00000 10.00000 11.00000 12.00000 TestCastBFloat16 (EURUSD,H1) ArrayToFP16: TestCastBFloat16 (EURUSD,H1) 16256 16384 16448 16512 16544 16576 16608 16640 16656 16672 16688 16704 TestCastBFloat16 (EURUSD,H1) 0 input value =1.000000 Hex float16 = [80,3f] ushort value=16256 TestCastBFloat16 (EURUSD,H1) 1 input value =2.000000 Hex float16 = [00,40] ushort value=16384 TestCastBFloat16 (EURUSD,H1) 2 input value =3.000000 Hex float16 = [40,40] ushort value=16448 TestCastBFloat16 (EURUSD,H1) 3 input value =4.000000 Hex float16 = [80,40] ushort value=16512 TestCastBFloat16 (EURUSD,H1) 4 input value =5.000000 Hex float16 = [a0,40] ushort value=16544 TestCastBFloat16 (EURUSD,H1) 5 input value =6.000000 Hex float16 = [c0,40] ushort value=16576 TestCastBFloat16 (EURUSD,H1) 6 input value =7.000000 Hex float16 = [e0,40] ushort value=16608 TestCastBFloat16 (EURUSD,H1) 7 input value =8.000000 Hex float16 = [00,41] ushort value=16640 TestCastBFloat16 (EURUSD,H1) 8 input value =9.000000 Hex float16 = [10,41] ushort value=16656 TestCastBFloat16 (EURUSD,H1) 9 input value =10.000000 Hex float16 = [20,41] ushort value=16672 TestCastBFloat16 (EURUSD,H1) 10 input value =11.000000 Hex float16 = [30,41] ushort value=16688 TestCastBFloat16 (EURUSD,H1) 11 input value =12.000000 Hex float16 = [40,41] ushort value=16704 TestCastBFloat16 (EURUSD,H1) ONNX input array: TestCastBFloat16 (EURUSD,H1) [uc] [value] TestCastBFloat16 (EURUSD,H1) [ 0] ... 16256 TestCastBFloat16 (EURUSD,H1) [ 1] ... 16384 TestCastBFloat16 (EURUSD,H1) [ 2] ... 16448 TestCastBFloat16 (EURUSD,H1) [ 3] ... 16512 TestCastBFloat16 (EURUSD,H1) [ 4] ... 16544 TestCastBFloat16 (EURUSD,H1) [ 5] ... 16576 TestCastBFloat16 (EURUSD,H1) [ 6] ... 16608 TestCastBFloat16 (EURUSD,H1) [ 7] ... 16640 TestCastBFloat16 (EURUSD,H1) [ 8] ... 16656 TestCastBFloat16 (EURUSD,H1) [ 9] ... 16672 TestCastBFloat16 (EURUSD,H1) [10] ... 16688 TestCastBFloat16 (EURUSD,H1) [11] ... 16704 TestCastBFloat16 (EURUSD,H1) ONNX output array: TestCastBFloat16 (EURUSD,H1) [uc] [value] TestCastBFloat16 (EURUSD,H1) [ 0] ... 1.00000 TestCastBFloat16 (EURUSD,H1) [ 1] ... 2.00000 TestCastBFloat16 (EURUSD,H1) [ 2] ... 3.00000 TestCastBFloat16 (EURUSD,H1) [ 3] ... 4.00000 TestCastBFloat16 (EURUSD,H1) [ 4] ... 5.00000 TestCastBFloat16 (EURUSD,H1) [ 5] ... 6.00000 TestCastBFloat16 (EURUSD,H1) [ 6] ... 7.00000 TestCastBFloat16 (EURUSD,H1) [ 7] ... 8.00000 TestCastBFloat16 (EURUSD,H1) [ 8] ... 9.00000 TestCastBFloat16 (EURUSD,H1) [ 9] ... 10.00000 TestCastBFloat16 (EURUSD,H1) [10] ... 11.00000 TestCastBFloat16 (EURUSD,H1) [11] ... 12.00000 TestCastBFloat16 (EURUSD,H1) 0 output float 1.000000 = [00,00,80,3f] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 1 output float 2.000000 = [00,00,00,40] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 2 output float 3.000000 = [00,00,40,40] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 3 output float 4.000000 = [00,00,80,40] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 4 output float 5.000000 = [00,00,a0,40] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 5 output float 6.000000 = [00,00,c0,40] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 6 output float 7.000000 = [00,00,e0,40] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 7 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 8 output float 9.000000 = [00,00,10,41] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 9 output float 10.000000 = [00,00,20,41] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 10 output float 11.000000 = [00,00,30,41] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 11 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastBFloat16 (EURUSD,H1) test=RunCastBFloat16ToFloat sum_error=0.000000
1.2. Formato FP8
Modelos linguísticos modernos podem conter bilhões de parâmetros. O treinamento de modelos usando números FP16 já demonstrou sua eficiência. A transição de números de ponto flutuante de 16 bits para FP8 permite reduzir pela metade os requisitos de memória, além de acelerar o treinamento e a execução dos modelos.
O formato FP8 (número de ponto flutuante de 8 bits) é um dos tipos de dados usados para representar números de ponto flutuante. Em FP8, cada número é representado por 8 bits de dados, que geralmente são divididos em três componentes: sinal, expoente e mantissa. Este formato oferece um compromisso entre precisão e eficiência de armazenamento de dados, tornando-o atraente para uso em aplicações onde é necessário economizar memória e recursos computacionais.
Uma das principais vantagens do FP8 é sua eficiência no processamento de grandes volumes de dados. Devido à representação compacta dos números, o FP8 permite reduzir os requisitos de memória e acelerar os cálculos. Isso é especialmente importante em aplicações de aprendizado de máquina e inteligência artificial, onde o processamento de grandes conjuntos de dados é comum.
Além disso, o FP8 pode ser útil para a implementação de operações de baixo nível, como cálculos aritméticos e processamento de sinais. Seu formato compacto o torna adequado para uso em sistemas embarcados e aplicações onde os recursos são limitados. No entanto, é importante notar que o FP8 tem suas limitações relacionadas à sua precisão restrita. Em algumas aplicações, onde é necessária alta precisão nos cálculos, como em cálculos científicos ou análises financeiras, o uso do FP8 pode ser insuficiente.
1.2.1. Formatos fp8_e5m2 e fp8_e4m3
Em 2022, foram publicados dois artigos que introduzem números de ponto flutuante armazenados em um byte, ao contrário dos números float32, armazenados em 4 bytes.
No artigo FP8 Formats for Deep Learning (2022) da NVIDIA, Intel e ARM, são introduzidos dois tipos, seguindo as especificações IEEE. O primeiro tipo é o E4M3, 1 bit para o sinal, 4 bits para o expoente e 3 bits para a mantissa. O segundo tipo é o E5M2, 1 bit para o sinal, 5 bits para o expoente e 2 bits para a mantissa. O primeiro tipo é geralmente usado para pesos, o segundo para gradientes.
O segundo artigo "8-bit Numerical Formats For Deep Neural Networks" apresenta tipos semelhantes. O padrão IEEE atribui o mesmo valor a +0 (ou número inteiro 0) e -0 (ou número inteiro 128). O artigo propõe atribuir diferentes valores float a esses dois números. Além disso, são exploradas várias divisões entre expoente e mantissa, e mostrado que E4M3 e E5M2 são os melhores.
Como resultado, no ONNX (a partir da versão 1.15.0) foram introduzidos 4 novos tipos:
- E4M3FN: 1 bit para o sinal, 4 bits para o expoente, 3 bits para a mantissa, apenas valores NaN e sem valores infinitos (FN),
- E4M3FNUZ: 1 bit para o sinal, 4 bits para o expoente, 3 bits para a mantissa, apenas valores NaN e sem valores infinitos (FN), sem zero negativo (UZ)
- E5M2: 1 bit para o sinal, 5 bits para o expoente, 2 bits para a mantissa,
- E5M2FNUZ: 1 bit para o sinal, 5 bits para o expoente, 2 bits para a mantissa, apenas valores NaN e sem valores infinitos (FN), sem zero negativo (UZ)
A implementação geralmente depende do hardware. NVIDIA, Intel e Arm implementam E4M3FN, e E5M2 é implementado em processadores gráficos modernos. A GraphCore faz o mesmo apenas com E4M3FNUZ e E5M2FNUZ.
Vamos resumir as principais informações sobre o tipo FP8 conforme o artigo NVIDIA Hopper: H100 and FP8 Support.
Fig.5. Formato de representação em bits dos números de ponto flutuante FP8_E4M3
Tabl.3. Números de ponto flutuante no formato E5M2
Tabl.4. Números de ponto flutuante no formato E4M3
A comparação dos intervalos de valores positivos dos números FP8_E4M3 e FP8_E5M2 são mostrados na figura
Fig.6. Comparação dos intervalos de valores positivos dos números FP8 (fonte)
A comparação da precisão das operações aritméticas (Add, Mul, Div) para números nos formatos FP8_E5M2 e FP8_E4M3 são mostrados na figura:
Fig.7. Comparação da precisão das operações aritméticas para números nos formatos float8_e5m2 e float8_e4m3 (fonte)
Uso recomendado de números no formato FP8:
- E4M3 para tensores de pesos e ativação;
- E5M2 para tensores de gradientes.
1.2.2. Testes de execução do operador ONNX Cast para FLOAT8
Neste exemplo, considera-se a conversão de vários tipos FLOAT8 para float.
Modelos ONNX com operação Cast:
- https://github.com/onnx/onnx/tree/main/onnx/backend/test/data/node/test_cast_FLOAT8E4M3FN_to_FLOAT.onnx
- https://github.com/onnx/onnx/tree/main/onnx/backend/test/data/node/test_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnx
- https://github.com/onnx/onnx/tree/main/onnx/backend/test/data/node/test_cast_FLOAT8E5M2_to_FLOAT.onnx
- https://github.com/onnx/onnx/tree/main/onnx/backend/test/data/node/test_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnx
Fig.8. Parâmetros de entrada e saída do modelo test_cast_FLOAT8E4M3FN_to_FLOAT.onnx no MetaEditor
Fig.9. Parâmetros de entrada e saída do modelo test_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnx no MetaEditor
Fig.10. Parâmetros de entrada e saída do modelo test_cast_FLOAT8E5M2_to_FLOAT.onnx no MetaEditor
Fig.11. Parâmetros de entrada e saída do modelo test_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnx no MetaEditor
Exemplo:
//+------------------------------------------------------------------+ //| TestCastBFloat8.mq5 | //| Copyright 2024, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2024, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" #resource "models\\test_cast_FLOAT8E4M3FN_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E4M3FN_to_FLOAT[]; #resource "models\\test_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E4M3FNUZ_to_FLOAT[]; #resource "models\\test_cast_FLOAT8E5M2_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E5M2_to_FLOAT[]; #resource "models\\test_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E5M2FNUZ_to_FLOAT[]; #define TEST_PASSED 0 #define TEST_FAILED 1 //+------------------------------------------------------------------+ //| union for data conversion | //+------------------------------------------------------------------+ template<typename T> union U { uchar uc[sizeof(T)]; T value; }; //+------------------------------------------------------------------+ //| ArrayToHexString | //+------------------------------------------------------------------+ template<typename T> string ArrayToHexString(const T &data[],uint length=16) { string res; for(uint n=0; n<MathMin(length,data.Size()); n++) res+="," + StringFormat("%.2x",data[n]); StringSetCharacter(res,0,'['); return(res+"]"); } //+------------------------------------------------------------------+ //| ArrayToString | //+------------------------------------------------------------------+ template<typename T> string ArrayToString(const U<T> &data[],uint length=16) { string res; for(uint n=0; n<MathMin(length,data.Size()); n++) res+="," + (string)data[n].value; StringSetCharacter(res,0,'['); return(res+"]"); } //+------------------------------------------------------------------+ //| PatchONNXModel | //+------------------------------------------------------------------+ long CreatePatchedModel(const uchar &original_model[]) { uchar patched_model[]; ArrayCopy(patched_model,original_model); //--- special ONNX model patch(IR=9,Opset=20) patched_model[1]=0x09; patched_model[ArraySize(patched_model)-1]=0x14; return(OnnxCreateFromBuffer(patched_model,ONNX_DEFAULT)); } //+------------------------------------------------------------------+ //| PrepareShapes | //+------------------------------------------------------------------+ bool PrepareShapes(long model_handle) { //--- configure input shape ulong input_shape[]= {3,5}; if(!OnnxSetInputShape(model_handle,0,input_shape)) { PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError()); OnnxRelease(model_handle); return(false); } //--- configure output shape ulong output_shape[]= {3,5}; if(!OnnxSetOutputShape(model_handle,0,output_shape)) { PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError()); OnnxRelease(model_handle); return(false); } return(true); } //+------------------------------------------------------------------+ //| RunCastFloat8Float | //+------------------------------------------------------------------+ bool RunCastFloat8ToFloat(long model_handle,const ENUM_FLOAT8_FORMAT fmt) { PrintFormat("TEST: %s(%s)",__FUNCTION__,EnumToString(fmt)); //--- float test_data[15] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}; uchar data_float8[15] = {}; if(!ArrayToFP8(data_float8,test_data,fmt)) { Print("error in ArrayToFP8. error code=",GetLastError()); OnnxRelease(model_handle); return(false); } U<uchar> input_float8_values[3*5]; U<float> output_float_values[3*5]; float test_data_float[]; //--- convert float8 to float if(!ArrayFromFP8(test_data_float,data_float8,fmt)) { Print("error in ArrayFromFP8. error code=",GetLastError()); OnnxRelease(model_handle); return(false); } for(uint i=0; i<data_float8.Size(); i++) { input_float8_values[i].value=data_float8[i]; PrintFormat("%d input value =%f Hex float8 = %s ushort value=%d",i,test_data_float[i],ArrayToHexString(input_float8_values[i].uc),input_float8_values[i].value); } Print("ONNX input array: ",ArrayToString(input_float8_values)); //--- execute model (convert float8 to float using ONNX) if(!OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float8_values,output_float_values)) { PrintFormat("error in OnnxRun. error code=%d",GetLastError()); OnnxRelease(model_handle); return(false); } Print("ONNX output array: ",ArrayToString(output_float_values)); //--- calculate error (compare ONNX and ArrayFromFP8 results) double sum_error=0.0; for(uint i=0; i<test_data.Size(); i++) { double delta=test_data_float[i]-(double)output_float_values[i].value; sum_error+=MathAbs(delta); PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToHexString(output_float_values[i].uc),delta); } //--- PrintFormat("%s(%s): sum_error=%f\n",__FUNCTION__,EnumToString(fmt),sum_error); return(true); } //+------------------------------------------------------------------+ //| TestModel | //+------------------------------------------------------------------+ bool TestModel(const uchar &model[],const ENUM_FLOAT8_FORMAT fmt) { //--- create patched model long model_handle=CreatePatchedModel(model); if(model_handle==INVALID_HANDLE) return(false); //--- prepare input and output shapes if(!PrepareShapes(model_handle)) return(false); //--- run ONNX model if(!RunCastFloat8ToFloat(model_handle,fmt)) return(false); //--- release model handle OnnxRelease(model_handle); return(true); } //+------------------------------------------------------------------+ //| Script program start function | //+------------------------------------------------------------------+ int OnStart(void) { //--- run ONNX model if(!TestModel(ExtModel_FLOAT8E4M3FN_to_FLOAT,FLOAT_FP8_E4M3FN)) return(TEST_FAILED); //--- run ONNX model if(!TestModel(ExtModel_FLOAT8E4M3FNUZ_to_FLOAT,FLOAT_FP8_E4M3FNUZ)) return(TEST_FAILED); //--- run ONNX model if(!TestModel(ExtModel_FLOAT8E5M2_to_FLOAT,FLOAT_FP8_E5M2FN)) return(TEST_FAILED); //--- run ONNX model if(!TestModel(ExtModel_FLOAT8E5M2FNUZ_to_FLOAT,FLOAT_FP8_E5M2FNUZ)) return(TEST_FAILED); return(TEST_PASSED); } //+------------------------------------------------------------------+
Resultado:
TestCastFloat8 (EURUSD,H1) TEST: RunCastFloat8ToFloat(FLOAT_FP8_E4M3FN) TestCastFloat8 (EURUSD,H1) 0 input value =1.000000 Hex float8 = [38] ushort value=56 TestCastFloat8 (EURUSD,H1) 1 input value =2.000000 Hex float8 = [40] ushort value=64 TestCastFloat8 (EURUSD,H1) 2 input value =3.000000 Hex float8 = [44] ushort value=68 TestCastFloat8 (EURUSD,H1) 3 input value =4.000000 Hex float8 = [48] ushort value=72 TestCastFloat8 (EURUSD,H1) 4 input value =5.000000 Hex float8 = [4a] ushort value=74 TestCastFloat8 (EURUSD,H1) 5 input value =6.000000 Hex float8 = [4c] ushort value=76 TestCastFloat8 (EURUSD,H1) 6 input value =7.000000 Hex float8 = [4e] ushort value=78 TestCastFloat8 (EURUSD,H1) 7 input value =8.000000 Hex float8 = [50] ushort value=80 TestCastFloat8 (EURUSD,H1) 8 input value =9.000000 Hex float8 = [51] ushort value=81 TestCastFloat8 (EURUSD,H1) 9 input value =10.000000 Hex float8 = [52] ushort value=82 TestCastFloat8 (EURUSD,H1) 10 input value =11.000000 Hex float8 = [53] ushort value=83 TestCastFloat8 (EURUSD,H1) 11 input value =12.000000 Hex float8 = [54] ushort value=84 TestCastFloat8 (EURUSD,H1) 12 input value =13.000000 Hex float8 = [55] ushort value=85 TestCastFloat8 (EURUSD,H1) 13 input value =14.000000 Hex float8 = [56] ushort value=86 TestCastFloat8 (EURUSD,H1) 14 input value =15.000000 Hex float8 = [57] ushort value=87 TestCastFloat8 (EURUSD,H1) ONNX input array: [56,64,68,72,74,76,78,80,81,82,83,84,85,86,87] TestCastFloat8 (EURUSD,H1) ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0] TestCastFloat8 (EURUSD,H1) 0 output float 1.000000 = [00,00,80,3f] difference=0.000000 TestCastFloat8 (EURUSD,H1) 1 output float 2.000000 = [00,00,00,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 2 output float 3.000000 = [00,00,40,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 3 output float 4.000000 = [00,00,80,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 4 output float 5.000000 = [00,00,a0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 5 output float 6.000000 = [00,00,c0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 6 output float 7.000000 = [00,00,e0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 7 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 8 output float 9.000000 = [00,00,10,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 9 output float 10.000000 = [00,00,20,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 10 output float 11.000000 = [00,00,30,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 11 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 12 output float 13.000000 = [00,00,50,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 13 output float 14.000000 = [00,00,60,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 14 output float 15.000000 = [00,00,70,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) RunCastFloat8ToFloat(FLOAT_FP8_E4M3FN): sum_error=0.000000 TestCastFloat8 (EURUSD,H1) TestCastFloat8 (EURUSD,H1) TEST: RunCastFloat8ToFloat(FLOAT_FP8_E4M3FNUZ) TestCastFloat8 (EURUSD,H1) 0 input value =1.000000 Hex float8 = [40] ushort value=64 TestCastFloat8 (EURUSD,H1) 1 input value =2.000000 Hex float8 = [48] ushort value=72 TestCastFloat8 (EURUSD,H1) 2 input value =3.000000 Hex float8 = [4c] ushort value=76 TestCastFloat8 (EURUSD,H1) 3 input value =4.000000 Hex float8 = [50] ushort value=80 TestCastFloat8 (EURUSD,H1) 4 input value =5.000000 Hex float8 = [52] ushort value=82 TestCastFloat8 (EURUSD,H1) 5 input value =6.000000 Hex float8 = [54] ushort value=84 TestCastFloat8 (EURUSD,H1) 6 input value =7.000000 Hex float8 = [56] ushort value=86 TestCastFloat8 (EURUSD,H1) 7 input value =8.000000 Hex float8 = [58] ushort value=88 TestCastFloat8 (EURUSD,H1) 8 input value =9.000000 Hex float8 = [59] ushort value=89 TestCastFloat8 (EURUSD,H1) 9 input value =10.000000 Hex float8 = [5a] ushort value=90 TestCastFloat8 (EURUSD,H1) 10 input value =11.000000 Hex float8 = [5b] ushort value=91 TestCastFloat8 (EURUSD,H1) 11 input value =12.000000 Hex float8 = [5c] ushort value=92 TestCastFloat8 (EURUSD,H1) 12 input value =13.000000 Hex float8 = [5d] ushort value=93 TestCastFloat8 (EURUSD,H1) 13 input value =14.000000 Hex float8 = [5e] ushort value=94 TestCastFloat8 (EURUSD,H1) 14 input value =15.000000 Hex float8 = [5f] ushort value=95 TestCastFloat8 (EURUSD,H1) ONNX input array: [64,72,76,80,82,84,86,88,89,90,91,92,93,94,95] TestCastFloat8 (EURUSD,H1) ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0] TestCastFloat8 (EURUSD,H1) 0 output float 1.000000 = [00,00,80,3f] difference=0.000000 TestCastFloat8 (EURUSD,H1) 1 output float 2.000000 = [00,00,00,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 2 output float 3.000000 = [00,00,40,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 3 output float 4.000000 = [00,00,80,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 4 output float 5.000000 = [00,00,a0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 5 output float 6.000000 = [00,00,c0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 6 output float 7.000000 = [00,00,e0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 7 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 8 output float 9.000000 = [00,00,10,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 9 output float 10.000000 = [00,00,20,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 10 output float 11.000000 = [00,00,30,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 11 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 12 output float 13.000000 = [00,00,50,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 13 output float 14.000000 = [00,00,60,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 14 output float 15.000000 = [00,00,70,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) RunCastFloat8ToFloat(FLOAT_FP8_E4M3FNUZ): sum_error=0.000000 TestCastFloat8 (EURUSD,H1) TestCastFloat8 (EURUSD,H1) TEST: RunCastFloat8ToFloat(FLOAT_FP8_E5M2FN) TestCastFloat8 (EURUSD,H1) 0 input value =1.000000 Hex float8 = [3c] ushort value=60 TestCastFloat8 (EURUSD,H1) 1 input value =2.000000 Hex float8 = [40] ushort value=64 TestCastFloat8 (EURUSD,H1) 2 input value =3.000000 Hex float8 = [42] ushort value=66 TestCastFloat8 (EURUSD,H1) 3 input value =4.000000 Hex float8 = [44] ushort value=68 TestCastFloat8 (EURUSD,H1) 4 input value =5.000000 Hex float8 = [45] ushort value=69 TestCastFloat8 (EURUSD,H1) 5 input value =6.000000 Hex float8 = [46] ushort value=70 TestCastFloat8 (EURUSD,H1) 6 input value =7.000000 Hex float8 = [47] ushort value=71 TestCastFloat8 (EURUSD,H1) 7 input value =8.000000 Hex float8 = [48] ushort value=72 TestCastFloat8 (EURUSD,H1) 8 input value =8.000000 Hex float8 = [48] ushort value=72 TestCastFloat8 (EURUSD,H1) 9 input value =10.000000 Hex float8 = [49] ushort value=73 TestCastFloat8 (EURUSD,H1) 10 input value =12.000000 Hex float8 = [4a] ushort value=74 TestCastFloat8 (EURUSD,H1) 11 input value =12.000000 Hex float8 = [4a] ushort value=74 TestCastFloat8 (EURUSD,H1) 12 input value =12.000000 Hex float8 = [4a] ushort value=74 TestCastFloat8 (EURUSD,H1) 13 input value =14.000000 Hex float8 = [4b] ushort value=75 TestCastFloat8 (EURUSD,H1) 14 input value =16.000000 Hex float8 = [4c] ushort value=76 TestCastFloat8 (EURUSD,H1) ONNX input array: [60,64,66,68,69,70,71,72,72,73,74,74,74,75,76] TestCastFloat8 (EURUSD,H1) ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,8.0,10.0,12.0,12.0,12.0,14.0,16.0] TestCastFloat8 (EURUSD,H1) 0 output float 1.000000 = [00,00,80,3f] difference=0.000000 TestCastFloat8 (EURUSD,H1) 1 output float 2.000000 = [00,00,00,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 2 output float 3.000000 = [00,00,40,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 3 output float 4.000000 = [00,00,80,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 4 output float 5.000000 = [00,00,a0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 5 output float 6.000000 = [00,00,c0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 6 output float 7.000000 = [00,00,e0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 7 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 8 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 9 output float 10.000000 = [00,00,20,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 10 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 11 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 12 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 13 output float 14.000000 = [00,00,60,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 14 output float 16.000000 = [00,00,80,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) RunCastFloat8ToFloat(FLOAT_FP8_E5M2FN): sum_error=0.000000 TestCastFloat8 (EURUSD,H1) TestCastFloat8 (EURUSD,H1) TEST: RunCastFloat8ToFloat(FLOAT_FP8_E5M2FNUZ) TestCastFloat8 (EURUSD,H1) 0 input value =1.000000 Hex float8 = [40] ushort value=64 TestCastFloat8 (EURUSD,H1) 1 input value =2.000000 Hex float8 = [44] ushort value=68 TestCastFloat8 (EURUSD,H1) 2 input value =3.000000 Hex float8 = [46] ushort value=70 TestCastFloat8 (EURUSD,H1) 3 input value =4.000000 Hex float8 = [48] ushort value=72 TestCastFloat8 (EURUSD,H1) 4 input value =5.000000 Hex float8 = [49] ushort value=73 TestCastFloat8 (EURUSD,H1) 5 input value =6.000000 Hex float8 = [4a] ushort value=74 TestCastFloat8 (EURUSD,H1) 6 input value =7.000000 Hex float8 = [4b] ushort value=75 TestCastFloat8 (EURUSD,H1) 7 input value =8.000000 Hex float8 = [4c] ushort value=76 TestCastFloat8 (EURUSD,H1) 8 input value =8.000000 Hex float8 = [4c] ushort value=76 TestCastFloat8 (EURUSD,H1) 9 input value =10.000000 Hex float8 = [4d] ushort value=77 TestCastFloat8 (EURUSD,H1) 10 input value =12.000000 Hex float8 = [4e] ushort value=78 TestCastFloat8 (EURUSD,H1) 11 input value =12.000000 Hex float8 = [4e] ushort value=78 TestCastFloat8 (EURUSD,H1) 12 input value =12.000000 Hex float8 = [4e] ushort value=78 TestCastFloat8 (EURUSD,H1) 13 input value =14.000000 Hex float8 = [4f] ushort value=79 TestCastFloat8 (EURUSD,H1) 14 input value =16.000000 Hex float8 = [50] ushort value=80 TestCastFloat8 (EURUSD,H1) ONNX input array: [64,68,70,72,73,74,75,76,76,77,78,78,78,79,80] TestCastFloat8 (EURUSD,H1) ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,8.0,10.0,12.0,12.0,12.0,14.0,16.0] TestCastFloat8 (EURUSD,H1) 0 output float 1.000000 = [00,00,80,3f] difference=0.000000 TestCastFloat8 (EURUSD,H1) 1 output float 2.000000 = [00,00,00,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 2 output float 3.000000 = [00,00,40,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 3 output float 4.000000 = [00,00,80,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 4 output float 5.000000 = [00,00,a0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 5 output float 6.000000 = [00,00,c0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 6 output float 7.000000 = [00,00,e0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 7 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 8 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 9 output float 10.000000 = [00,00,20,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 10 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 11 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 12 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 13 output float 14.000000 = [00,00,60,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 14 output float 16.000000 = [00,00,80,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) RunCastFloat8ToFloat(FLOAT_FP8_E5M2FNUZ): sum_error=0.000000 TestCastFloat8 (EURUSD,H1)
2. Exemplo de uso do ONNX para aumentar a resolução de imagens
Nesta seção, vamos considerar um exemplo de uso dos modelos ESRGAN para aumentar a resolução de imagens.
ESRGAN, ou Enhanced Super-Resolution Generative Adversarial Networks, representa uma arquitetura poderosa de redes neurais projetada para a tarefa de super-resolução de imagens. O ESRGAN foi desenvolvido com o objetivo de melhorar a qualidade das imagens, aumentando sua resolução para um nível superior. Isso é alcançado treinando uma rede neural profunda em um grande conjunto de dados de baixa resolução e suas imagens correspondentes de alta qualidade. O ESRGAN utiliza a arquitetura de redes generativas adversariais (GANs), que consiste em dois componentes principais: gerador e discriminador. O gerador é responsável por criar imagens de alta resolução, enquanto o discriminador é treinado para distinguir as imagens geradas das reais.
No núcleo da arquitetura ESRGAN estão os blocos residuais, que ajudam a extrair e preservar características importantes das imagens em diferentes níveis de abstração. Isso permite que a rede recupere detalhes e texturas com alta qualidade.
Para alcançar alta qualidade e generalidade na tarefa de super-resolução, o ESRGAN requer extensos conjuntos de dados para treinamento. Isso permite que a rede aprenda diferentes estilos e características de imagens, tornando-a mais adaptável a diversos tipos de dados de entrada. O ESRGAN pode ser usado para melhorar a qualidade das imagens em várias áreas, incluindo fotografia, diagnóstico médico, produção de filmes e vídeos, design gráfico e muito mais. Sua flexibilidade e eficiência o tornam um dos métodos líderes na área de super-resolução de imagens.
O ESRGAN representa um avanço significativo na área de processamento de imagens e inteligência artificial, abrindo novas possibilidades para a criação e melhoria de imagens.
2.1. Exemplo de execução do modelo ONNX com float32
Para a execução do exemplo, é necessário baixar o arquivo https://github.com/amannm/super-resolution-service/blob/main/models/esrgan.onnx e copiá-lo na pasta \MQL5\Scripts\models.
O modelo ESRGAN.onnx contém ~1200 operações ONNX, algumas das quais são apresentadas na Fig.12
Fig.12. Modelo ESRGAN no MetaEditor
Fig.13. Modelo ESRGAN no Netron
Ele começa com o carregamento do modelo esrgan.onnx, em seguida, a imagem original em formato BMP é selecionada e carregada. Depois disso, a imagem é convertida em canais RGB separados, que são alimentados na entrada do modelo. O modelo realiza o processo de aumento do tamanho da imagem em 4 vezes, após o qual a imagem ampliada passa por uma transformação inversa e é preparada para exibição.
Para exibição, é utilizada a biblioteca Canvas, e para a execução do modelo, a biblioteca ONNX Runtime. Após a execução do programa, a imagem ampliada é salva em um arquivo com "_upscaled" adicionado ao nome do arquivo original. As principais funções incluem a pré-processamento e pós-processamento da imagem, bem como a execução do modelo para aumento do tamanho da imagem.
//+------------------------------------------------------------------+ //| ESRGAN.mq5 | //| Copyright 2024, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2024, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" //+------------------------------------------------------------------+ //| 4x image upscaling demo using ESRGAN | //| esrgan.onnx model from | //| https://github.com/amannm/super-resolution-service/ | //+------------------------------------------------------------------+ //| Xintao Wang et al (2018) | //| ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks| //| https://arxiv.org/abs/1809.00219 | //+------------------------------------------------------------------+ #resource "models\\esrgan.onnx" as uchar ExtModel[]; #include <Canvas\Canvas.mqh> //+------------------------------------------------------------------+ //| clamp | //+------------------------------------------------------------------+ float clamp(float value, float minValue, float maxValue) { return MathMin(MathMax(value, minValue), maxValue); } //+------------------------------------------------------------------+ //| Preprocessing | //+------------------------------------------------------------------+ bool Preprocessing(float &data[],uint &image_data[],int &image_width,int &image_height) { //--- checkup if(image_height==0 || image_width==0) return(false); //--- prepare destination array with separated RGB channels for ONNX model int data_count=3*image_width*image_height; if(ArrayResize(data,data_count)!=data_count) { Print("ArrayResize failed"); return(false); } //--- converting for(int y=0; y<image_height; y++) for(int x=0; x<image_width; x++) { //--- load source RGB int offset=y*image_width+x; uint clr =image_data[offset]; uchar r =GETRGBR(clr); uchar g =GETRGBG(clr); uchar b =GETRGBB(clr); //--- store RGB components as separated channels int offset_ch1=0*image_width*image_height+offset; int offset_ch2=1*image_width*image_height+offset; int offset_ch3=2*image_width*image_height+offset; data[offset_ch1]=r/255.0f; data[offset_ch2]=g/255.0f; data[offset_ch3]=b/255.0f; } //--- return(true); } //+------------------------------------------------------------------+ //| PostProcessing | //+------------------------------------------------------------------+ bool PostProcessing(const float &data[], uint &image_data[], const int &image_width, const int &image_height) { //--- checks if(image_height == 0 || image_width == 0) return(false); int data_count=image_width*image_height; if(ArraySize(data)!=3*data_count) return(false); if(ArrayResize(image_data,data_count)!=data_count) return(false); //--- for(int y=0; y<image_height; y++) for(int x=0; x<image_width; x++) { int offset =y*image_width+x; int offset_ch1=0*image_width*image_height+offset; int offset_ch2=1*image_width*image_height+offset; int offset_ch3=2*image_width*image_height+offset; //--- rescale to [0..255] float r=clamp(data[offset_ch1]*255,0,255); float g=clamp(data[offset_ch2]*255,0,255); float b=clamp(data[offset_ch3]*255,0,255); //--- set color image_data image_data[offset]=XRGB(uchar(r),uchar(g),uchar(b)); } //--- return(true); } //+------------------------------------------------------------------+ //| ShowImage | //+------------------------------------------------------------------+ bool ShowImage(CCanvas &canvas,const string name,const int x0,const int y0,const int image_width,const int image_height, const uint &image_data[]) { if(ArraySize(image_data)==0 || name=="") return(false); //--- prepare canvas canvas.CreateBitmapLabel(name,x0,y0,image_width,image_height,COLOR_FORMAT_XRGB_NOALPHA); //--- copy image to canvas for(int y=0; y<image_height; y++) for(int x=0; x<image_width; x++) canvas.PixelSet(x,y,image_data[y*image_width+x]); //--- ready to draw canvas.Update(true); return(true); } //+------------------------------------------------------------------+ //| Script program start function | //+------------------------------------------------------------------+ int OnStart(void) { //--- select BMP from <data folder>\MQL5\Files string image_path[1]; if(FileSelectDialog("Select BMP image",NULL,"Bitmap files (*.bmp)|*.bmp",FSD_FILE_MUST_EXIST,image_path,"lenna-original4.bmp")!=1) { Print("file not selected"); return(-1); } //--- load BMP into array uint image_data[]; int image_width; int image_height; if(!CCanvas::LoadBitmap(image_path[0],image_data,image_width,image_height)) { PrintFormat("CCanvas::LoadBitmap failed with error %d",GetLastError()); return(-1); } //--- convert RGB image to separated RGB channels float input_data[]; Preprocessing(input_data,image_data,image_width,image_height); PrintFormat("input array size=%d",ArraySize(input_data)); //--- load model long model_handle=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT); if(model_handle==INVALID_HANDLE) { PrintFormat("OnnxCreate error %d",GetLastError()); return(-1); } PrintFormat("model loaded successfully"); PrintFormat("original: width=%d, height=%d Size=%d",image_width,image_height,ArraySize(image_data)); //--- set input shape ulong input_shape[]={1,3,image_height,image_width}; if(!OnnxSetInputShape(model_handle,0,input_shape)) { PrintFormat("error in OnnxSetInputShape. error code=%d",GetLastError()); OnnxRelease(model_handle); return(-1); } //--- upscaled image size int new_image_width =4*image_width; int new_image_height=4*image_height; ulong output_shape[]= {1,3,new_image_height,new_image_width}; if(!OnnxSetOutputShape(model_handle,0,output_shape)) { PrintFormat("error in OnnxSetOutputShape. error code=%d",GetLastError()); OnnxRelease(model_handle); return(-1); } //--- run the model float output_data[]; int new_data_count=3*new_image_width*new_image_height; if(ArrayResize(output_data,new_data_count)!=new_data_count) { OnnxRelease(model_handle); return(-1); } if(!OnnxRun(model_handle,ONNX_DEBUG_LOGS,input_data,output_data)) { PrintFormat("error in OnnxRun. error code=%d",GetLastError()); OnnxRelease(model_handle); return(-1); } Print("model successfully executed, output data size ",ArraySize(output_data)); OnnxRelease(model_handle); //--- postprocessing uint new_image[]; PostProcessing(output_data,new_image,new_image_width,new_image_height); //--- show images CCanvas canvas_original,canvas_scaled; ShowImage(canvas_original,"original_image",new_image_width,0,image_width,image_height,image_data); ShowImage(canvas_scaled,"upscaled_image",0,0,new_image_width,new_image_height,new_image); //--- save upscaled image StringReplace(image_path[0],".bmp","_upscaled.bmp"); Print(ResourceSave(canvas_scaled.ResourceName(),image_path[0])); //--- while(!IsStopped()) Sleep(100); return(0); } //+------------------------------------------------------------------+
Resultado:
Fig.14. Resultado do modelo ESRGAN.onnx (160x200 -> 640x800)
Neste exemplo, uma imagem de 160x200 foi ampliada em 4 vezes (para 640x800) usando o modelo ESRGAN.onnx.
2.2. Exemplo de execução do modelo ONNX com float16
Para converter modelos em float16, utilizaremos o método descrito em Create Float16 and Mixed Precision Models.
# Copyright 2024, MetaQuotes Ltd. # https://www.mql5.com import onnx from onnxconverter_common import float16 from sys import argv # Define the path for saving the model data_path = argv[0] last_index = data_path.rfind("\\") + 1 data_path = data_path[0:last_index] # конвертация модели в float16 model_path = data_path+'\\models\\esrgan.onnx' modelfp16_path = data_path+'\\models\\esrgan_float16.onnx' model = onnx.load(model_path) model_fp16 = float16.convert_float_to_float16(model) onnx.save(model_fp16, modelfp16_path)
Após a conversão, o tamanho do arquivo foi reduzido pela metade (de 64MB para 32MB).
As alterações no código são mínimas:
//+------------------------------------------------------------------+ //| ESRGAN_float16.mq5 | //| Copyright 2024, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2024, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" //+------------------------------------------------------------------+ //| 4x image upscaling demo using ESRGAN | //| esrgan.onnx model from | //| https://github.com/amannm/super-resolution-service/ | //+------------------------------------------------------------------+ //| Xintao Wang et al (2018) | //| ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks| //| https://arxiv.org/abs/1809.00219 | //+------------------------------------------------------------------+ #resource "models\\esrgan_float16.onnx" as uchar ExtModel[]; #include <Canvas\Canvas.mqh> //+------------------------------------------------------------------+ //| clamp | //+------------------------------------------------------------------+ float clamp(float value, float minValue, float maxValue) { return MathMin(MathMax(value, minValue), maxValue); } //+------------------------------------------------------------------+ //| Preprocessing | //+------------------------------------------------------------------+ bool Preprocessing(float &data[],uint &image_data[],int &image_width,int &image_height) { //--- checkup if(image_height==0 || image_width==0) return(false); //--- prepare destination array with separated RGB channels for ONNX model int data_count=3*image_width*image_height; if(ArrayResize(data,data_count)!=data_count) { Print("ArrayResize failed"); return(false); } //--- converting for(int y=0; y<image_height; y++) for(int x=0; x<image_width; x++) { //--- load source RGB int offset=y*image_width+x; uint clr =image_data[offset]; uchar r =GETRGBR(clr); uchar g =GETRGBG(clr); uchar b =GETRGBB(clr); //--- store RGB components as separated channels int offset_ch1=0*image_width*image_height+offset; int offset_ch2=1*image_width*image_height+offset; int offset_ch3=2*image_width*image_height+offset; data[offset_ch1]=r/255.0f; data[offset_ch2]=g/255.0f; data[offset_ch3]=b/255.0f; } //--- return(true); } //+------------------------------------------------------------------+ //| PostProcessing | //+------------------------------------------------------------------+ bool PostProcessing(const float &data[], uint &image_data[], const int &image_width, const int &image_height) { //--- checks if(image_height == 0 || image_width == 0) return(false); int data_count=image_width*image_height; if(ArraySize(data)!=3*data_count) return(false); if(ArrayResize(image_data,data_count)!=data_count) return(false); //--- for(int y=0; y<image_height; y++) for(int x=0; x<image_width; x++) { int offset =y*image_width+x; int offset_ch1=0*image_width*image_height+offset; int offset_ch2=1*image_width*image_height+offset; int offset_ch3=2*image_width*image_height+offset; //--- rescale to [0..255] float r=clamp(data[offset_ch1]*255,0,255); float g=clamp(data[offset_ch2]*255,0,255); float b=clamp(data[offset_ch3]*255,0,255); //--- set color image_data image_data[offset]=XRGB(uchar(r),uchar(g),uchar(b)); } //--- return(true); } //+------------------------------------------------------------------+ //| ShowImage | //+------------------------------------------------------------------+ bool ShowImage(CCanvas &canvas,const string name,const int x0,const int y0,const int image_width,const int image_height, const uint &image_data[]) { if(ArraySize(image_data)==0 || name=="") return(false); //--- prepare canvas canvas.CreateBitmapLabel(name,x0,y0,image_width,image_height,COLOR_FORMAT_XRGB_NOALPHA); //--- copy image to canvas for(int y=0; y<image_height; y++) for(int x=0; x<image_width; x++) canvas.PixelSet(x,y,image_data[y*image_width+x]); //--- ready to draw canvas.Update(true); return(true); } //+------------------------------------------------------------------+ //| Script program start function | //+------------------------------------------------------------------+ int OnStart(void) { //--- select BMP from <data folder>\MQL5\Files string image_path[1]; if(FileSelectDialog("Select BMP image",NULL,"Bitmap files (*.bmp)|*.bmp",FSD_FILE_MUST_EXIST,image_path,"lenna.bmp")!=1) { Print("file not selected"); return(-1); } //--- load BMP into array uint image_data[]; int image_width; int image_height; if(!CCanvas::LoadBitmap(image_path[0],image_data,image_width,image_height)) { PrintFormat("CCanvas::LoadBitmap failed with error %d",GetLastError()); return(-1); } //--- convert RGB image to separated RGB channels float input_data[]; Preprocessing(input_data,image_data,image_width,image_height); PrintFormat("input array size=%d",ArraySize(input_data)); ushort input_data_float16[]; if(!ArrayToFP16(input_data_float16,input_data,FLOAT_FP16)) { Print("error in ArrayToFP16. error code=",GetLastError()); return(false); } //--- load model long model_handle=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT); if(model_handle==INVALID_HANDLE) { PrintFormat("OnnxCreate error %d",GetLastError()); return(-1); } PrintFormat("model loaded successfully"); PrintFormat("original: width=%d, height=%d Size=%d",image_width,image_height,ArraySize(image_data)); //--- set input shape ulong input_shape[]={1,3,image_height,image_width}; if(!OnnxSetInputShape(model_handle,0,input_shape)) { PrintFormat("error in OnnxSetInputShape. error code=%d",GetLastError()); OnnxRelease(model_handle); return(-1); } //--- upscaled image size int new_image_width =4*image_width; int new_image_height=4*image_height; ulong output_shape[]= {1,3,new_image_height,new_image_width}; if(!OnnxSetOutputShape(model_handle,0,output_shape)) { PrintFormat("error in OnnxSetOutputShape. error code=%d",GetLastError()); OnnxRelease(model_handle); return(-1); } //--- run the model float output_data[]; ushort output_data_float16[]; int new_data_count=3*new_image_width*new_image_height; if(ArrayResize(output_data_float16,new_data_count)!=new_data_count) { OnnxRelease(model_handle); return(-1); } if(!OnnxRun(model_handle,ONNX_NO_CONVERSION,input_data_float16,output_data_float16)) { PrintFormat("error in OnnxRun. error code=%d",GetLastError()); OnnxRelease(model_handle); return(-1); } Print("model successfully executed, output data size ",ArraySize(output_data)); OnnxRelease(model_handle); if(!ArrayFromFP16(output_data,output_data_float16,FLOAT_FP16)) { Print("error in ArrayFromFP16. error code=",GetLastError()); return(false); } //--- postprocessing uint new_image[]; PostProcessing(output_data,new_image,new_image_width,new_image_height); //--- show images CCanvas canvas_original,canvas_scaled; ShowImage(canvas_original,"original_image",new_image_width,0,image_width,image_height,image_data); ShowImage(canvas_scaled,"upscaled_image",0,0,new_image_width,new_image_height,new_image); //--- save upscaled image StringReplace(image_path[0],".bmp","_upscaled.bmp"); Print(ResourceSave(canvas_scaled.ResourceName(),image_path[0])); //--- while(!IsStopped()) Sleep(100); return(0); } //+------------------------------------------------------------------+
As alterações no código necessárias para a execução do modelo convertido para o formato float16 estão destacadas em cores.
Resultado:
Fig.15. Resultado do modelo ESRGAN_float16.onnx (160x200 -> 640x800)
Assim, o uso de números float16 em vez de float32 permite reduzir o tamanho do arquivo do modelo ONNX pela metade (de 64MB para 32MB).
Ao executar modelos com números float16, a qualidade das imagens permaneceu a mesma, sendo visualmente difícil encontrar diferenças:
Fig.16. Comparação dos resultados do modelo ESRGAN para float e float16
As alterações no código são mínimas, bastando apenas cuidar da conversão dos dados de entrada e saída.
Neste caso, após a conversão para float16, a qualidade do funcionamento do modelo não mudou significativamente, no entanto, ao analisar dados financeiros, deve-se buscar cálculos com a máxima precisão possível.
Conclusões
O uso de novos tipos de dados para números de ponto flutuante permite reduzir o tamanho dos modelos ONNX sem perda significativa de qualidade.
O pré-processamento e pós-processamento de dados são significativamente simplificados pelo uso das funções de conversão de dados ArrayToFP16/ArrayFromFP16 e ArrayToFP8/ArrayFromFP8.
Para trabalhar com modelos ONNX convertidos, são necessárias mudanças mínimas no código.
Traduzido do russo pela MetaQuotes Ltd.
Artigo original: https://www.mql5.com/ru/articles/14330
- Aplicativos de negociação gratuitos
- 8 000+ sinais para cópia
- Notícias econômicas para análise dos mercados financeiros
Você concorda com a política do site e com os termos de uso