
float16およびfloat8形式のONNXモデルを扱う
内容
- 1. ONNXモデルを扱うための新しいデータ型
- 1.1. FP16形式
- 1.1.1. FLOAT16に対するONNXキャスト演算子の実行テスト
- 1.1.2. BFLOAT16に対するONNXキャスト演算子の実行テスト
- 1.2. FP8形式
- 1.2.1. fp8_e5m2およびfp8_e4m3 FP8形式
- 1.2.2. FLOAT8に対するONNXキャスト演算子の実行テスト
- 2. 画像の高解析度化でのONNXの使用
- 2.1. float32を使用したONNXモデルの実行
- 2.2.float16を使用したONNXモデルの実行
- 終わりに
機械学習と人工知能技術の進歩に伴い、モデルを扱うプロセスを最適化する必要性が高まっています。モデル操作の効率は、それを表現するために使用されるデータ形式に直接依存します。近年、深層学習モデルを扱うために特別に設計された新しいデータ型がいくつか登場しています。
この記事では、最新のONNXモデルで積極的に使われ始めているfloat16とfloat8という2つの新しいデータ形式に焦点を当てます。これらの形式は、より正確だがリソースを大量に消費する浮動小数点データ形式の代替オプションとなっています。また、パフォーマンスと精度の最適なバランスを提供し、さまざまな機械学習タスクに特に魅力的なものとなっています。float16形式とfloat8形式の主な特徴と利点を探り、それらを標準的なfloatとdouble形式に変換する関数を紹介します。
これは、開発者や研究者が、プロジェクトやモデルにおいてこれらの形式を効果的に使用する方法をよりよく理解するのに役立つでしょう。一例として、画質補正に使用されるESRGAN ONNXモデルの動作を検証します。
1. ONNXモデルを扱うための新しいデータ型
計算を高速化するために、一部のモデルではFloat16やFloat8のような精度の低いデータ型が利用されています。
MQL5言語では、これらの新しいデータ型をONNXモデルで使用するサポートが追加され、8ビットおよび16ビットの浮動小数点表現を操作できるようになっています。
次のスクリプトは、ENUM_ONNX_DATA_TYPE列挙の要素の完全なリストを出力します。
//+------------------------------------------------------------------+ //| ONNX_Data_Types.mq5 | //| Copyright 2024, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2024, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" //+------------------------------------------------------------------+ //| Script program start function | //+------------------------------------------------------------------+ void OnStart() { //--- for(int i=0; i<21; i++) PrintFormat("%2d %s",i,EnumToString(ENUM_ONNX_DATA_TYPE(i))); }
以下が出力です。
0: ONNX_DATA_TYPE_UNDEFINED 1: ONNX_DATA_TYPE_FLOAT 2: ONNX_DATA_TYPE_UINT8 3: ONNX_DATA_TYPE_INT8 4: ONNX_DATA_TYPE_UINT16 5: ONNX_DATA_TYPE_INT16 6: ONNX_DATA_TYPE_INT32 7: ONNX_DATA_TYPE_INT64 8: ONNX_DATA_TYPE_STRING 9: ONNX_DATA_TYPE_BOOL 10: ONNX_DATA_TYPE_FLOAT16 11: ONNX_DATA_TYPE_DOUBLE 12: ONNX_DATA_TYPE_UINT32 13: ONNX_DATA_TYPE_UINT64 14: ONNX_DATA_TYPE_COMPLEX64 15: ONNX_DATA_TYPE_COMPLEX128 16: ONNX_DATA_TYPE_BFLOAT16 17: ONNX_DATA_TYPE_FLOAT8E4M3FN 18: ONNX_DATA_TYPE_FLOAT8E4M3FNUZ 19: ONNX_DATA_TYPE_FLOAT8E5M2 20: ONNX_DATA_TYPE_FLOAT8E5M2FNUZ
こうして、このようなデータを使用してONNXモデルを実行することが可能になりました。
さらに、MQL5に、データ変換のための関数が追加されました。
bool ArrayToFP16(ushort &dst_array[],const float &src_array[],ENUM_FLOAT16_FORMAT fmt); bool ArrayToFP16(ushort &dst_array[],const double &src_array[],ENUM_FLOAT16_FORMAT fmt); bool ArrayToFP8(uchar &dst_array[],const float &src_array[],ENUM_FLOAT8_FORMAT fmt); bool ArrayToFP8(uchar &dst_array[],const double &src_array[],ENUM_FLOAT8_FORMAT fmt); bool ArrayFromFP16(float &dst_array[],const ushort &src_array[],ENUM_FLOAT16_FORMAT fmt); bool ArrayFromFP16(double &dst_array[],const ushort &src_array[],ENUM_FLOAT16_FORMAT fmt); bool ArrayFromFP8(float &dst_array[],const uchar &src_array[],ENUM_FLOAT8_FORMAT fmt); bool ArrayFromFP8(double &dst_array[],const uchar &src_array[],ENUM_FLOAT8_FORMAT fmt);
16ビットと8ビットの浮動小数点形式は異なる可能性があるので、変換関数のfmtパラメータは、どちらの形式の数値を処理する必要があるかを指定しなければなりません。
16ビットバージョンでは、新しいENUM_FLOAT16_FORMAT列挙が使用され、現在以下の値を持ちます。
- FLOAT_FP16:標準の16ビット形式(別名ハーフフロート)
- FLOAT_BFP16:特殊なBF16浮動小数点形式
- FLOAT_FP8_E4M3FN:8ビット浮動小数点数(4ビット指数、3ビット仮数)。通常は係数として使用
- FLOAT_FP8_E4M3FNUZ:8ビット浮動小数点数(4ビット指数、3ビット仮数)。NaNをサポート。負のゼロとInfのサポートなし。通常は係数として使用
- FLOAT_FP8_E5M2FN:8ビット浮動小数点数(5ビット指数、2ビット仮数)。NaNとInfをサポート。通常は勾配に使用
- FLOAT_FP8_E5M2FNUZ:8ビット浮動小数点数(5ビット指数、2ビット仮数)。NaNとInfをサポート。負のゼロのサポートなし。勾配に使用
1.1. FP16形式
FLOAT16形式とBFLOAT16形式は、浮動小数点数を表すために使用されるデータ型です。
「半精度浮動小数点」形式としても知られるFLOAT16は、16ビットで浮動小数点数を表します。この形式は、精度と計算効率のバランスを提供します。FLOAT16は、大量のデータを処理する際に高いパフォーマンスが求められる深層学習やニューラルネットワークで広く使用されています。この形式は、数値のサイズを小さくすることで計算の高速化を可能にします。これは、GPU (graphics processing unit)でディープニューラルネットワークを訓練する場合に特に重要です。
BFLOAT16(BF16)も16ビットを使用しますが、数値の表現方法がFLOAT16と異なります。この形式では、8ビットが指数を表すために割り当てられ、残りの7ビットが仮数を表すために使用されます。この形式は、深層学習や人工知能、特にGoogleTensor Processing Unit(TPU)プロセッサで使用するために開発されました。BFLOAT16は、ニューラルネットワークを訓練する際に優れたパフォーマンスを発揮し、計算の高速化に効果的に利用できます。
どちらの形式にも利点と限界があります。FLOAT16は精度が高いですが、ストレージと計算に多くのリソースを必要とします。一方、BFLOAT16は、データ処理のパフォーマンスと効率は高いですが、精度が落ちる可能性があります。
図1:浮動小数点数のビット表現の形式 - FLOAT16とBFLOAT16
表1:FLOAT16形式の浮動小数点数
1.1.1. FLOAT16に対するONNXキャスト演算子の実行テスト
例として、FLOAT16型のデータをfloat型とdouble型に変換する作業を考えてみましょう。
ONNXモデルのキャスト操作:
- https://github.com/onnx/onnx/tree/main/onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT
- https://github.com/onnx/onnx/tree/main/onnx/backend/test/data/node/test_cast_FLOAT16_to_DOUBLE
図2:モデルtest_cast_FLOAT16_to_DOUBLE.onnxの入出力パラメータ
図3:モデルtest_cast_FLOAT16_to_FLOAT.onnxの入出力パラメータ
ONNXモデルのプロパティの説明からわかるように、入力にはONNX_DATA_TYPE_FLOAT16型のデータが必要で、モデルはONNX_DATA_TYPE_FLOAT形式の出力データを返します。
値を変換するには、FLOAT_FP16パラメータを持つ関数ArrayToFP16()とArrayFromFP16()を使用します。
例:
//+------------------------------------------------------------------+ //| TestCastFloat16.mq5 | //| Copyright 2024, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2024, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" #resource "models\\test_cast_FLOAT16_to_DOUBLE.onnx" as const uchar ExtModel1[]; #resource "models\\test_cast_FLOAT16_to_FLOAT.onnx" as const uchar ExtModel2[]; //+------------------------------------------------------------------+ //| union for data conversion | //+------------------------------------------------------------------+ template<typename T> union U { uchar uc[sizeof(T)]; T value; }; //+------------------------------------------------------------------+ //| ArrayToString | //+------------------------------------------------------------------+ template<typename T> string ArrayToString(const T &data[],uint length=16) { string res; for(uint n=0; n<MathMin(length,data.Size()); n++) res+="," + StringFormat("%.2x",data[n]); StringSetCharacter(res,0,'['); return res+"]"; } //+------------------------------------------------------------------+ //| PatchONNXModel | //+------------------------------------------------------------------+ void PatchONNXModel(const uchar &original_model[],uchar &patched_model[]) { ArrayCopy(patched_model,original_model,0,0,WHOLE_ARRAY); //--- special ONNX model patch(IR=9,Opset=20) patched_model[1]=0x09; patched_model[ArraySize(patched_model)-1]=0x14; } //+------------------------------------------------------------------+ //| CreateModel | //+------------------------------------------------------------------+ bool CreateModel(long &model_handle,const uchar &model[]) { model_handle=INVALID_HANDLE; ulong flags=ONNX_DEFAULT; //ulong flags=ONNX_DEBUG_LOGS; //--- model_handle=OnnxCreateFromBuffer(model,flags); if(model_handle==INVALID_HANDLE) return(false); //--- return(true); } //+------------------------------------------------------------------+ //| PrepareShapes | //+------------------------------------------------------------------+ bool PrepareShapes(long model_handle) { ulong input_shape1[]= {3,4}; if(!OnnxSetInputShape(model_handle,0,input_shape1)) { PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError()); //-- OnnxRelease(model_handle); return(false); } //--- ulong output_shape[]= {3,4}; if(!OnnxSetOutputShape(model_handle,0,output_shape)) { PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError()); //-- OnnxRelease(model_handle); return(false); } //--- return(true); } //+------------------------------------------------------------------+ //| RunCastFloat16ToDouble | //+------------------------------------------------------------------+ bool RunCastFloat16ToDouble(long model_handle) { PrintFormat("test=%s",__FUNCTION__); double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12}; ushort data_uint16[12]; if(!ArrayToFP16(data_uint16,test_data,FLOAT_FP16)) { Print("error in ArrayToFP16. error code=",GetLastError()); return(false); } Print("test array:"); ArrayPrint(test_data); Print("ArrayToFP16:"); ArrayPrint(data_uint16); U<ushort> input_float16_values[3*4]; U<double> output_double_values[3*4]; float test_data_float[]; if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_FP16)) { Print("error in ArrayFromFP16. error code=",GetLastError()); return(false); } for(int i=0; i<12; i++) { input_float16_values[i].value=data_uint16[i]; PrintFormat("%d input value =%f Hex float16 = %s ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value); } Print("ONNX input array:"); ArrayPrint(input_float16_values); bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_double_values); if(!res) { PrintFormat("error in OnnxRun. error code=%d",GetLastError()); return(false); } Print("ONNX output array:"); ArrayPrint(output_double_values); //--- double sum_error=0.0; for(int i=0; i<12; i++) { double delta=test_data[i]-output_double_values[i].value; sum_error+=MathAbs(delta); PrintFormat("%d output double %f = %s difference=%f",i,output_double_values[i].value,ArrayToString(output_double_values[i].uc),delta); } //--- PrintFormat("test=%s sum_error=%f",__FUNCTION__,sum_error); //--- return(true); } //+------------------------------------------------------------------+ //| RunCastFloat16ToFloat | //+------------------------------------------------------------------+ bool RunCastFloat16ToFloat(long model_handle) { PrintFormat("test=%s",__FUNCTION__); double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12}; ushort data_uint16[12]; if(!ArrayToFP16(data_uint16,test_data,FLOAT_FP16)) { Print("error in ArrayToFP16. error code=",GetLastError()); return(false); } Print("test array:"); ArrayPrint(test_data); Print("ArrayToFP16:"); ArrayPrint(data_uint16); U<ushort> input_float16_values[3*4]; U<float> output_float_values[3*4]; float test_data_float[]; if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_FP16)) { Print("error in ArrayFromFP16. error code=",GetLastError()); return(false); } for(int i=0; i<12; i++) { input_float16_values[i].value=data_uint16[i]; PrintFormat("%d input value =%f Hex float16 = %s ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value); } Print("ONNX input array:"); ArrayPrint(input_float16_values); bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_float_values); if(!res) { PrintFormat("error in OnnxRun. error code=%d",GetLastError()); return(false); } Print("ONNX output array:"); ArrayPrint(output_float_values); //--- double sum_error=0.0; for(int i=0; i<12; i++) { double delta=test_data[i]-(double)output_float_values[i].value; sum_error+=MathAbs(delta); PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToString(output_float_values[i].uc),delta); } //--- PrintFormat("test=%s sum_error=%f",__FUNCTION__,sum_error); //--- return(true); } //+------------------------------------------------------------------+ //| TestCastFloat16ToFloat | //+------------------------------------------------------------------+ bool TestCastFloat16ToFloat(const uchar &res_model[]) { uchar model[]; PatchONNXModel(res_model,model); //--- get model handle long model_handle=INVALID_HANDLE; //--- get model handle if(!CreateModel(model_handle,model)) return(false); //--- prepare input and output shapes if(!PrepareShapes(model_handle)) return(false); //--- run ONNX model if(!RunCastFloat16ToFloat(model_handle)) return(false); //--- release model handle OnnxRelease(model_handle); //--- return(true); } //+------------------------------------------------------------------+ //| TestCastFloat16ToDouble | //+------------------------------------------------------------------+ bool TestCastFloat16ToDouble(const uchar &res_model[]) { uchar model[]; PatchONNXModel(res_model,model); //--- long model_handle=INVALID_HANDLE; //--- get model handle if(!CreateModel(model_handle,model)) return(false); //--- prepare input and output shapes if(!PrepareShapes(model_handle)) return(false); //--- run ONNX model if(!RunCastFloat16ToDouble(model_handle)) return(false); //--- release model handle OnnxRelease(model_handle); //--- return(true); } //+------------------------------------------------------------------+ //| Script program start function | //+------------------------------------------------------------------+ int OnStart(void) { if(!TestCastFloat16ToDouble(ExtModel1)) return 1; if(!TestCastFloat16ToFloat(ExtModel2)) return 1; //--- return 0; } //+------------------------------------------------------------------+
以下が出力です。
TestCastFloat16 (EURUSD,H1) test=RunCastFloat16ToDouble TestCastFloat16 (EURUSD,H1) test array: TestCastFloat16 (EURUSD,H1) 1.00000 2.00000 3.00000 4.00000 5.00000 6.00000 7.00000 8.00000 9.00000 10.00000 11.00000 12.00000 TestCastFloat16 (EURUSD,H1) ArrayToFP16: TestCastFloat16 (EURUSD,H1) 15360 16384 16896 17408 17664 17920 18176 18432 18560 18688 18816 18944 TestCastFloat16 (EURUSD,H1) 0 input value =1.000000 Hex float16 = [00,3c] ushort value=15360 TestCastFloat16 (EURUSD,H1) 1 input value =2.000000 Hex float16 = [00,40] ushort value=16384 TestCastFloat16 (EURUSD,H1) 2 input value =3.000000 Hex float16 = [00,42] ushort value=16896 TestCastFloat16 (EURUSD,H1) 3 input value =4.000000 Hex float16 = [00,44] ushort value=17408 TestCastFloat16 (EURUSD,H1) 4 input value =5.000000 Hex float16 = [00,45] ushort value=17664 TestCastFloat16 (EURUSD,H1) 5 input value =6.000000 Hex float16 = [00,46] ushort value=17920 TestCastFloat16 (EURUSD,H1) 6 input value =7.000000 Hex float16 = [00,47] ushort value=18176 TestCastFloat16 (EURUSD,H1) 7 input value =8.000000 Hex float16 = [00,48] ushort value=18432 TestCastFloat16 (EURUSD,H1) 8 input value =9.000000 Hex float16 = [80,48] ushort value=18560 TestCastFloat16 (EURUSD,H1) 9 input value =10.000000 Hex float16 = [00,49] ushort value=18688 TestCastFloat16 (EURUSD,H1) 10 input value =11.000000 Hex float16 = [80,49] ushort value=18816 TestCastFloat16 (EURUSD,H1) 11 input value =12.000000 Hex float16 = [00,4a] ushort value=18944 TestCastFloat16 (EURUSD,H1) ONNX input array: TestCastFloat16 (EURUSD,H1) [uc] [value] TestCastFloat16 (EURUSD,H1) [ 0] ... 15360 TestCastFloat16 (EURUSD,H1) [ 1] ... 16384 TestCastFloat16 (EURUSD,H1) [ 2] ... 16896 TestCastFloat16 (EURUSD,H1) [ 3] ... 17408 TestCastFloat16 (EURUSD,H1) [ 4] ... 17664 TestCastFloat16 (EURUSD,H1) [ 5] ... 17920 TestCastFloat16 (EURUSD,H1) [ 6] ... 18176 TestCastFloat16 (EURUSD,H1) [ 7] ... 18432 TestCastFloat16 (EURUSD,H1) [ 8] ... 18560 TestCastFloat16 (EURUSD,H1) [ 9] ... 18688 TestCastFloat16 (EURUSD,H1) [10] ... 18816 TestCastFloat16 (EURUSD,H1) [11] ... 18944 TestCastFloat16 (EURUSD,H1) ONNX output array: TestCastFloat16 (EURUSD,H1) [uc] [value] TestCastFloat16 (EURUSD,H1) [ 0] ... 1.00000 TestCastFloat16 (EURUSD,H1) [ 1] ... 2.00000 TestCastFloat16 (EURUSD,H1) [ 2] ... 3.00000 TestCastFloat16 (EURUSD,H1) [ 3] ... 4.00000 TestCastFloat16 (EURUSD,H1) [ 4] ... 5.00000 TestCastFloat16 (EURUSD,H1) [ 5] ... 6.00000 TestCastFloat16 (EURUSD,H1) [ 6] ... 7.00000 TestCastFloat16 (EURUSD,H1) [ 7] ... 8.00000 TestCastFloat16 (EURUSD,H1) [ 8] ... 9.00000 TestCastFloat16 (EURUSD,H1) [ 9] ... 10.00000 TestCastFloat16 (EURUSD,H1) [10] ... 11.00000 TestCastFloat16 (EURUSD,H1) [11] ... 12.00000 TestCastFloat16 (EURUSD,H1) 0 output double 1.000000 = [00,00,00,00,00,00,f0,3f] difference=0.000000 TestCastFloat16 (EURUSD,H1) 1 output double 2.000000 = [00,00,00,00,00,00,00,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 2 output double 3.000000 = [00,00,00,00,00,00,08,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 3 output double 4.000000 = [00,00,00,00,00,00,10,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 4 output double 5.000000 = [00,00,00,00,00,00,14,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 5 output double 6.000000 = [00,00,00,00,00,00,18,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 6 output double 7.000000 = [00,00,00,00,00,00,1c,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 7 output double 8.000000 = [00,00,00,00,00,00,20,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 8 output double 9.000000 = [00,00,00,00,00,00,22,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 9 output double 10.000000 = [00,00,00,00,00,00,24,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 10 output double 11.000000 = [00,00,00,00,00,00,26,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 11 output double 12.000000 = [00,00,00,00,00,00,28,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) test=RunCastFloat16ToDouble sum_error=0.000000 TestCastFloat16 (EURUSD,H1) test=RunCastFloat16ToFloat TestCastFloat16 (EURUSD,H1) test array: TestCastFloat16 (EURUSD,H1) 1.00000 2.00000 3.00000 4.00000 5.00000 6.00000 7.00000 8.00000 9.00000 10.00000 11.00000 12.00000 TestCastFloat16 (EURUSD,H1) ArrayToFP16: TestCastFloat16 (EURUSD,H1) 15360 16384 16896 17408 17664 17920 18176 18432 18560 18688 18816 18944 TestCastFloat16 (EURUSD,H1) 0 input value =1.000000 Hex float16 = [00,3c] ushort value=15360 TestCastFloat16 (EURUSD,H1) 1 input value =2.000000 Hex float16 = [00,40] ushort value=16384 TestCastFloat16 (EURUSD,H1) 2 input value =3.000000 Hex float16 = [00,42] ushort value=16896 TestCastFloat16 (EURUSD,H1) 3 input value =4.000000 Hex float16 = [00,44] ushort value=17408 TestCastFloat16 (EURUSD,H1) 4 input value =5.000000 Hex float16 = [00,45] ushort value=17664 TestCastFloat16 (EURUSD,H1) 5 input value =6.000000 Hex float16 = [00,46] ushort value=17920 TestCastFloat16 (EURUSD,H1) 6 input value =7.000000 Hex float16 = [00,47] ushort value=18176 TestCastFloat16 (EURUSD,H1) 7 input value =8.000000 Hex float16 = [00,48] ushort value=18432 TestCastFloat16 (EURUSD,H1) 8 input value =9.000000 Hex float16 = [80,48] ushort value=18560 TestCastFloat16 (EURUSD,H1) 9 input value =10.000000 Hex float16 = [00,49] ushort value=18688 TestCastFloat16 (EURUSD,H1) 10 input value =11.000000 Hex float16 = [80,49] ushort value=18816 TestCastFloat16 (EURUSD,H1) 11 input value =12.000000 Hex float16 = [00,4a] ushort value=18944 TestCastFloat16 (EURUSD,H1) ONNX input array: TestCastFloat16 (EURUSD,H1) [uc] [value] TestCastFloat16 (EURUSD,H1) [ 0] ... 15360 TestCastFloat16 (EURUSD,H1) [ 1] ... 16384 TestCastFloat16 (EURUSD,H1) [ 2] ... 16896 TestCastFloat16 (EURUSD,H1) [ 3] ... 17408 TestCastFloat16 (EURUSD,H1) [ 4] ... 17664 TestCastFloat16 (EURUSD,H1) [ 5] ... 17920 TestCastFloat16 (EURUSD,H1) [ 6] ... 18176 TestCastFloat16 (EURUSD,H1) [ 7] ... 18432 TestCastFloat16 (EURUSD,H1) [ 8] ... 18560 TestCastFloat16 (EURUSD,H1) [ 9] ... 18688 TestCastFloat16 (EURUSD,H1) [10] ... 18816 TestCastFloat16 (EURUSD,H1) [11] ... 18944 TestCastFloat16 (EURUSD,H1) ONNX output array: TestCastFloat16 (EURUSD,H1) [uc] [value] TestCastFloat16 (EURUSD,H1) [ 0] ... 1.00000 TestCastFloat16 (EURUSD,H1) [ 1] ... 2.00000 TestCastFloat16 (EURUSD,H1) [ 2] ... 3.00000 TestCastFloat16 (EURUSD,H1) [ 3] ... 4.00000 TestCastFloat16 (EURUSD,H1) [ 4] ... 5.00000 TestCastFloat16 (EURUSD,H1) [ 5] ... 6.00000 TestCastFloat16 (EURUSD,H1) [ 6] ... 7.00000 TestCastFloat16 (EURUSD,H1) [ 7] ... 8.00000 TestCastFloat16 (EURUSD,H1) [ 8] ... 9.00000 TestCastFloat16 (EURUSD,H1) [ 9] ... 10.00000 TestCastFloat16 (EURUSD,H1) [10] ... 11.00000 TestCastFloat16 (EURUSD,H1) [11] ... 12.00000 TestCastFloat16 (EURUSD,H1) 0 output float 1.000000 = [00,00,80,3f] difference=0.000000 TestCastFloat16 (EURUSD,H1) 1 output float 2.000000 = [00,00,00,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 2 output float 3.000000 = [00,00,40,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 3 output float 4.000000 = [00,00,80,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 4 output float 5.000000 = [00,00,a0,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 5 output float 6.000000 = [00,00,c0,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 6 output float 7.000000 = [00,00,e0,40] difference=0.000000 TestCastFloat16 (EURUSD,H1) 7 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat16 (EURUSD,H1) 8 output float 9.000000 = [00,00,10,41] difference=0.000000 TestCastFloat16 (EURUSD,H1) 9 output float 10.000000 = [00,00,20,41] difference=0.000000 TestCastFloat16 (EURUSD,H1) 10 output float 11.000000 = [00,00,30,41] difference=0.000000 TestCastFloat16 (EURUSD,H1) 11 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat16 (EURUSD,H1) test=RunCastFloat16ToFloat sum_error=0.000000
1.1.2. BFLOAT16に対するONNXキャスト演算子の実行テスト
この例では、BFLOAT16からfloatへの変換を調べます。
ONNXモデルのキャスト操作:
図4:モデルtest_cast_BFLOAT16_to_FLOAT.onnxの入出力パラメータ
ONNX_DATA_TYPE_BFLOAT16型の入力データが必要で、モデルはONNX_DATA_TYPE_FLOAT型の出力データを返します。
値を変換するには、BFLOAT_FP16をパラメータとする関数ArrayToFP16()とArrayFromFP16()を使用します。
//+------------------------------------------------------------------+ //| TestCastBFloat16.mq5 | //| Copyright 2024, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2024, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" #resource "models\\test_cast_BFLOAT16_to_FLOAT.onnx" as const uchar ExtModel1[]; //+------------------------------------------------------------------+ //| union for data conversion | //+------------------------------------------------------------------+ template<typename T> union U { uchar uc[sizeof(T)]; T value; }; //+------------------------------------------------------------------+ //| ArrayToString | //+------------------------------------------------------------------+ template<typename T> string ArrayToString(const T &data[],uint length=16) { string res; for(uint n=0; n<MathMin(length,data.Size()); n++) res+="," + StringFormat("%.2x",data[n]); StringSetCharacter(res,0,'['); return res+"]"; } //+------------------------------------------------------------------+ //| PatchONNXModel | //+------------------------------------------------------------------+ void PatchONNXModel(const uchar &original_model[],uchar &patched_model[]) { ArrayCopy(patched_model,original_model,0,0,WHOLE_ARRAY); //--- special ONNX model patch(IR=9,Opset=20) patched_model[1]=0x09; patched_model[ArraySize(patched_model)-1]=0x14; } //+------------------------------------------------------------------+ //| CreateModel | //+------------------------------------------------------------------+ bool CreateModel(long &model_handle,const uchar &model[]) { model_handle=INVALID_HANDLE; ulong flags=ONNX_DEFAULT; //ulong flags=ONNX_DEBUG_LOGS; //--- model_handle=OnnxCreateFromBuffer(model,flags); if(model_handle==INVALID_HANDLE) return(false); //--- return(true); } //+------------------------------------------------------------------+ //| PrepareShapes | //+------------------------------------------------------------------+ bool PrepareShapes(long model_handle) { ulong input_shape1[]= {3,4}; if(!OnnxSetInputShape(model_handle,0,input_shape1)) { PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError()); //-- OnnxRelease(model_handle); return(false); } //--- ulong output_shape[]= {3,4}; if(!OnnxSetOutputShape(model_handle,0,output_shape)) { PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError()); //-- OnnxRelease(model_handle); return(false); } //--- return(true); } //+------------------------------------------------------------------+ //| RunCastBFloat16ToFloat | //+------------------------------------------------------------------+ bool RunCastBFloat16ToFloat(long model_handle) { PrintFormat("test=%s",__FUNCTION__); double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12}; ushort data_uint16[12]; if(!ArrayToFP16(data_uint16,test_data,FLOAT_BFP16)) { Print("error in ArrayToFP16. error code=",GetLastError()); return(false); } Print("test array:"); ArrayPrint(test_data); Print("ArrayToFP16:"); ArrayPrint(data_uint16); U<ushort> input_float16_values[3*4]; U<float> output_float_values[3*4]; float test_data_float[]; if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_BFP16)) { Print("error in ArrayFromFP16. error code=",GetLastError()); return(false); } for(int i=0; i<12; i++) { input_float16_values[i].value=data_uint16[i]; PrintFormat("%d input value =%f Hex float16 = %s ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value); } Print("ONNX input array:"); ArrayPrint(input_float16_values); bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_float_values); if(!res) { PrintFormat("error in OnnxRun. error code=%d",GetLastError()); return(false); } Print("ONNX output array:"); ArrayPrint(output_float_values); //--- double sum_error=0.0; for(int i=0; i<12; i++) { double delta=test_data[i]-(double)output_float_values[i].value; sum_error+=MathAbs(delta); PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToString(output_float_values[i].uc),delta); } //--- PrintFormat("test=%s sum_error=%f",__FUNCTION__,sum_error); //--- return(true); } //+------------------------------------------------------------------+ //| Script program start function | //+------------------------------------------------------------------+ int OnStart(void) { uchar model[]; PatchONNXModel(ExtModel1,model); //--- get model handle long model_handle=INVALID_HANDLE; //--- get model handle if(!CreateModel(model_handle,model)) return 1; //--- prepare input and output shapes if(!PrepareShapes(model_handle)) return 1; //--- run ONNX model if(!RunCastBFloat16ToFloat(model_handle)) return 1; //--- release model handle OnnxRelease(model_handle); //--- return 0; } //+------------------------------------------------------------------+以下が出力です。
TestCastBFloat16 (EURUSD,H1) test=RunCastBFloat16ToFloat TestCastBFloat16 (EURUSD,H1) test array: TestCastBFloat16 (EURUSD,H1) 1.00000 2.00000 3.00000 4.00000 5.00000 6.00000 7.00000 8.00000 9.00000 10.00000 11.00000 12.00000 TestCastBFloat16 (EURUSD,H1) ArrayToFP16: TestCastBFloat16 (EURUSD,H1) 16256 16384 16448 16512 16544 16576 16608 16640 16656 16672 16688 16704 TestCastBFloat16 (EURUSD,H1) 0 input value =1.000000 Hex float16 = [80,3f] ushort value=16256 TestCastBFloat16 (EURUSD,H1) 1 input value =2.000000 Hex float16 = [00,40] ushort value=16384 TestCastBFloat16 (EURUSD,H1) 2 input value =3.000000 Hex float16 = [40,40] ushort value=16448 TestCastBFloat16 (EURUSD,H1) 3 input value =4.000000 Hex float16 = [80,40] ushort value=16512 TestCastBFloat16 (EURUSD,H1) 4 input value =5.000000 Hex float16 = [a0,40] ushort value=16544 TestCastBFloat16 (EURUSD,H1) 5 input value =6.000000 Hex float16 = [c0,40] ushort value=16576 TestCastBFloat16 (EURUSD,H1) 6 input value =7.000000 Hex float16 = [e0,40] ushort value=16608 TestCastBFloat16 (EURUSD,H1) 7 input value =8.000000 Hex float16 = [00,41] ushort value=16640 TestCastBFloat16 (EURUSD,H1) 8 input value =9.000000 Hex float16 = [10,41] ushort value=16656 TestCastBFloat16 (EURUSD,H1) 9 input value =10.000000 Hex float16 = [20,41] ushort value=16672 TestCastBFloat16 (EURUSD,H1) 10 input value =11.000000 Hex float16 = [30,41] ushort value=16688 TestCastBFloat16 (EURUSD,H1) 11 input value =12.000000 Hex float16 = [40,41] ushort value=16704 TestCastBFloat16 (EURUSD,H1) ONNX input array: TestCastBFloat16 (EURUSD,H1) [uc] [value] TestCastBFloat16 (EURUSD,H1) [ 0] ... 16256 TestCastBFloat16 (EURUSD,H1) [ 1] ... 16384 TestCastBFloat16 (EURUSD,H1) [ 2] ... 16448 TestCastBFloat16 (EURUSD,H1) [ 3] ... 16512 TestCastBFloat16 (EURUSD,H1) [ 4] ... 16544 TestCastBFloat16 (EURUSD,H1) [ 5] ... 16576 TestCastBFloat16 (EURUSD,H1) [ 6] ... 16608 TestCastBFloat16 (EURUSD,H1) [ 7] ... 16640 TestCastBFloat16 (EURUSD,H1) [ 8] ... 16656 TestCastBFloat16 (EURUSD,H1) [ 9] ... 16672 TestCastBFloat16 (EURUSD,H1) [10] ... 16688 TestCastBFloat16 (EURUSD,H1) [11] ... 16704 TestCastBFloat16 (EURUSD,H1) ONNX output array: TestCastBFloat16 (EURUSD,H1) [uc] [value] TestCastBFloat16 (EURUSD,H1) [ 0] ... 1.00000 TestCastBFloat16 (EURUSD,H1) [ 1] ... 2.00000 TestCastBFloat16 (EURUSD,H1) [ 2] ... 3.00000 TestCastBFloat16 (EURUSD,H1) [ 3] ... 4.00000 TestCastBFloat16 (EURUSD,H1) [ 4] ... 5.00000 TestCastBFloat16 (EURUSD,H1) [ 5] ... 6.00000 TestCastBFloat16 (EURUSD,H1) [ 6] ... 7.00000 TestCastBFloat16 (EURUSD,H1) [ 7] ... 8.00000 TestCastBFloat16 (EURUSD,H1) [ 8] ... 9.00000 TestCastBFloat16 (EURUSD,H1) [ 9] ... 10.00000 TestCastBFloat16 (EURUSD,H1) [10] ... 11.00000 TestCastBFloat16 (EURUSD,H1) [11] ... 12.00000 TestCastBFloat16 (EURUSD,H1) 0 output float 1.000000 = [00,00,80,3f] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 1 output float 2.000000 = [00,00,00,40] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 2 output float 3.000000 = [00,00,40,40] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 3 output float 4.000000 = [00,00,80,40] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 4 output float 5.000000 = [00,00,a0,40] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 5 output float 6.000000 = [00,00,c0,40] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 6 output float 7.000000 = [00,00,e0,40] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 7 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 8 output float 9.000000 = [00,00,10,41] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 9 output float 10.000000 = [00,00,20,41] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 10 output float 11.000000 = [00,00,30,41] difference=0.000000 TestCastBFloat16 (EURUSD,H1) 11 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastBFloat16 (EURUSD,H1) test=RunCastBFloat16ToFloat sum_error=0.000000
1.2. FP8形式
現代の言語モデルは、何十億ものパラメータを含むことができます。FP16の数値を使用したモデルの訓練が効果的であることはすでに証明されています。16ビット浮動小数点数からFP8に移行することで、メモリ要件を半減し、訓練とモデル実行を高速化することができます。
FP8形式(8ビット浮動小数点数)は、浮動小数点数を表現するために使用されるデータ型の1つです。FP8では、各数値は8ビットのデータで表現され、通常、符号、指数、仮数の3つの要素に分けられます。この形式は、精度とストレージ効率の妥協点を提供し、メモリや計算リソースを節約する必要があるアプリケーションでの使用に魅力的です。
FP8の主な利点の1つは、大量のデータを効率的に処理できることです。そのコンパクトな数値表現により、FP8は必要なメモリを削減し、計算を高速化します。これは、大規模なデータセットを処理することが一般的な機械学習や人工知能アプリケーションでは特に重要です。
さらにFP8は、算術計算や信号処理などの低レベル演算の実装にも役立ちます。コンパクトな形式なので、組み込みシステムやリソースが限られているアプリケーションでの使用に適しています。ただし、FP8の精度に限界があることは注目に値します。科学計算や金融分析など、高精度の計算が要求されるアプリケーションでは、FP8の使用は不十分な場合があります。
1.2.1. fp8_e5m2およびfp8_e4m3 FP8形式
2022年には、4バイトで格納されるfloat32とは異なり、1バイトで格納される浮動小数点数を紹介する2つの記事が発表されました。
NVIDIA、Intel、ARMによる「FP8 Formats for Deep Learning」稿(2022年)では、IEEE仕様に従った2つの型が紹介されています。最初の型はE4M3で、符号が1ビット、指数が4ビット、仮数が3ビットです。2つ目の型はE5M2で、符号が1ビット、指数が5ビット、仮数が2ビットを使用します。最初の型は通常重みに使用され、2番目の型は勾配に使用されます。
2つ目の記事「8-bit Numerical Formats For Deep Neural Networks」では、同様の型を紹介しています。IEEE規格では、+0(または整数0)と-0(または整数128)に同じ値を割り当てています。この記事では、この2つの数値に異なるfloat値を割り当てることを提案しています。さらに、指数と仮数の間のさまざまな分割が検討され、E4M3とE5M2が最良であることが示されました。
その結果、ONNXは4つの新しい型を導入しました(バージョン1.15.0から)。
- E4M3FN:符号が1ビット、指数が4ビット、仮数が3ビット、NaN値のみ、無限(FN)なし
- E4M3FNUZ:符号が1ビット、指数が4ビット、仮数が3ビット、NaN値のみ、無限(FN)なし、負のゼロ(UZ)なし
- E5M2:符号が1ビット、指数が5ビット、仮数が2ビット
- E5M2FNUZ:符号が1ビット、指数が5ビット、仮数が2ビット、NaN値、無限(FN)なし、負のゼロ(UZ)なし
実装は通常、ハードウェアに依存します。NVIDIA、Intel、ArmはE4M3FNを実装していますが、E5M2は最新のGPU(グラフィックスプロセッシングユニット)に実装されています。GraphCoreでも同様、E4M3FNUZとE5M2FNUZが実装されています。
「NVIDIA Hopper:H100 and FP8 Support」稿にしたがって、FP8型に関する主な情報を簡単にまとめてみましょう。
図5:FP8形式のビット表現
表3:E5M2形式の浮動小数点数
表4:E4M3形式の浮動小数点数
正数の範囲の比較:FP8_E4M3とFP8_E5M2を図6に示します。
図6:FP8正数の範囲比較 (参照)
FP8_E5M2形式とFP8_E4M3形式の数値に対する算術演算(Add、Mul、Div)の精度の比較を図7に示します。
図7:float8_e5m2形式とfloat8_e4m3形式の数値に対する算術演算精度の比較 (参照)
FP8形式で推奨される数字の使用法は、次の通りです。
- 重みテンソルおよび活性化テンソル:E4M3
- 勾配テンソル:E5M2
1.2.2. FLOAT8用ONNX演算子キャストの実行テスト
この例では、様々な型のFLOAT8からfloatへの変換を考えます。
ONNXモデルのキャスト操作:
- https://github.com/onnx/onnx/tree/main/onnx/backend/test/data/node/test_cast_FLOAT8E4M3FN_to_FLOAT.onnx
- https://github.com/onnx/onnx/tree/main/onnx/backend/test/data/node/test_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnx
- https://github.com/onnx/onnx/tree/main/onnx/backend/test/data/node/test_cast_FLOAT8E5M2_to_FLOAT.onnx
- https://github.com/onnx/onnx/tree/main/onnx/backend/test/data/node/test_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnx
図8:MetaEditorにおけるモデルtest_cast_FLOAT8E4M3FN_to_FLOAT.onnxの入出力パラメータ
図9:MetaEditorにおけるモデルtest_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnxの入出力パラメータ
図10:MetaEditorにおけるモデルtest_cast_FLOAT8E5M2_to_FLOAT.onnxの入出力パラメータ
図11:MetaEditorにおけるモデルtest_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnxの入出力パラメータ
例:
//+------------------------------------------------------------------+ //| TestCastBFloat8.mq5 | //| Copyright 2024, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2024, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" #resource "models\\test_cast_FLOAT8E4M3FN_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E4M3FN_to_FLOAT[]; #resource "models\\test_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E4M3FNUZ_to_FLOAT[]; #resource "models\\test_cast_FLOAT8E5M2_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E5M2_to_FLOAT[]; #resource "models\\test_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E5M2FNUZ_to_FLOAT[]; #define TEST_PASSED 0 #define TEST_FAILED 1 //+------------------------------------------------------------------+ //| union for data conversion | //+------------------------------------------------------------------+ template<typename T> union U { uchar uc[sizeof(T)]; T value; }; //+------------------------------------------------------------------+ //| ArrayToHexString | //+------------------------------------------------------------------+ template<typename T> string ArrayToHexString(const T &data[],uint length=16) { string res; for(uint n=0; n<MathMin(length,data.Size()); n++) res+="," + StringFormat("%.2x",data[n]); StringSetCharacter(res,0,'['); return(res+"]"); } //+------------------------------------------------------------------+ //| ArrayToString | //+------------------------------------------------------------------+ template<typename T> string ArrayToString(const U<T> &data[],uint length=16) { string res; for(uint n=0; n<MathMin(length,data.Size()); n++) res+="," + (string)data[n].value; StringSetCharacter(res,0,'['); return(res+"]"); } //+------------------------------------------------------------------+ //| PatchONNXModel | //+------------------------------------------------------------------+ long CreatePatchedModel(const uchar &original_model[]) { uchar patched_model[]; ArrayCopy(patched_model,original_model); //--- special ONNX model patch(IR=9,Opset=20) patched_model[1]=0x09; patched_model[ArraySize(patched_model)-1]=0x14; return(OnnxCreateFromBuffer(patched_model,ONNX_DEFAULT)); } //+------------------------------------------------------------------+ //| PrepareShapes | //+------------------------------------------------------------------+ bool PrepareShapes(long model_handle) { //--- configure input shape ulong input_shape[]= {3,5}; if(!OnnxSetInputShape(model_handle,0,input_shape)) { PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError()); OnnxRelease(model_handle); return(false); } //--- configure output shape ulong output_shape[]= {3,5}; if(!OnnxSetOutputShape(model_handle,0,output_shape)) { PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError()); OnnxRelease(model_handle); return(false); } return(true); } //+------------------------------------------------------------------+ //| RunCastFloat8Float | //+------------------------------------------------------------------+ bool RunCastFloat8ToFloat(long model_handle,const ENUM_FLOAT8_FORMAT fmt) { PrintFormat("TEST: %s(%s)",__FUNCTION__,EnumToString(fmt)); //--- float test_data[15] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}; uchar data_float8[15] = {}; if(!ArrayToFP8(data_float8,test_data,fmt)) { Print("error in ArrayToFP8. error code=",GetLastError()); OnnxRelease(model_handle); return(false); } U<uchar> input_float8_values[3*5]; U<float> output_float_values[3*5]; float test_data_float[]; //--- convert float8 to float if(!ArrayFromFP8(test_data_float,data_float8,fmt)) { Print("error in ArrayFromFP8. error code=",GetLastError()); OnnxRelease(model_handle); return(false); } for(uint i=0; i<data_float8.Size(); i++) { input_float8_values[i].value=data_float8[i]; PrintFormat("%d input value =%f Hex float8 = %s ushort value=%d",i,test_data_float[i],ArrayToHexString(input_float8_values[i].uc),input_float8_values[i].value); } Print("ONNX input array: ",ArrayToString(input_float8_values)); //--- execute model (convert float8 to float using ONNX) if(!OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float8_values,output_float_values)) { PrintFormat("error in OnnxRun. error code=%d",GetLastError()); OnnxRelease(model_handle); return(false); } Print("ONNX output array: ",ArrayToString(output_float_values)); //--- calculate error (compare ONNX and ArrayFromFP8 results) double sum_error=0.0; for(uint i=0; i<test_data.Size(); i++) { double delta=test_data_float[i]-(double)output_float_values[i].value; sum_error+=MathAbs(delta); PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToHexString(output_float_values[i].uc),delta); } //--- PrintFormat("%s(%s): sum_error=%f\n",__FUNCTION__,EnumToString(fmt),sum_error); return(true); } //+------------------------------------------------------------------+ //| TestModel | //+------------------------------------------------------------------+ bool TestModel(const uchar &model[],const ENUM_FLOAT8_FORMAT fmt) { //--- create patched model long model_handle=CreatePatchedModel(model); if(model_handle==INVALID_HANDLE) return(false); //--- prepare input and output shapes if(!PrepareShapes(model_handle)) return(false); //--- run ONNX model if(!RunCastFloat8ToFloat(model_handle,fmt)) return(false); //--- release model handle OnnxRelease(model_handle); return(true); } //+------------------------------------------------------------------+ //| Script program start function | //+------------------------------------------------------------------+ int OnStart(void) { //--- run ONNX model if(!TestModel(ExtModel_FLOAT8E4M3FN_to_FLOAT,FLOAT_FP8_E4M3FN)) return(TEST_FAILED); //--- run ONNX model if(!TestModel(ExtModel_FLOAT8E4M3FNUZ_to_FLOAT,FLOAT_FP8_E4M3FNUZ)) return(TEST_FAILED); //--- run ONNX model if(!TestModel(ExtModel_FLOAT8E5M2_to_FLOAT,FLOAT_FP8_E5M2FN)) return(TEST_FAILED); //--- run ONNX model if(!TestModel(ExtModel_FLOAT8E5M2FNUZ_to_FLOAT,FLOAT_FP8_E5M2FNUZ)) return(TEST_FAILED); return(TEST_PASSED); } //+------------------------------------------------------------------+以下が出力です。
TestCastFloat8 (EURUSD,H1) TEST: RunCastFloat8ToFloat(FLOAT_FP8_E4M3FN) TestCastFloat8 (EURUSD,H1) 0 input value =1.000000 Hex float8 = [38] ushort value=56 TestCastFloat8 (EURUSD,H1) 1 input value =2.000000 Hex float8 = [40] ushort value=64 TestCastFloat8 (EURUSD,H1) 2 input value =3.000000 Hex float8 = [44] ushort value=68 TestCastFloat8 (EURUSD,H1) 3 input value =4.000000 Hex float8 = [48] ushort value=72 TestCastFloat8 (EURUSD,H1) 4 input value =5.000000 Hex float8 = [4a] ushort value=74 TestCastFloat8 (EURUSD,H1) 5 input value =6.000000 Hex float8 = [4c] ushort value=76 TestCastFloat8 (EURUSD,H1) 6 input value =7.000000 Hex float8 = [4e] ushort value=78 TestCastFloat8 (EURUSD,H1) 7 input value =8.000000 Hex float8 = [50] ushort value=80 TestCastFloat8 (EURUSD,H1) 8 input value =9.000000 Hex float8 = [51] ushort value=81 TestCastFloat8 (EURUSD,H1) 9 input value =10.000000 Hex float8 = [52] ushort value=82 TestCastFloat8 (EURUSD,H1) 10 input value =11.000000 Hex float8 = [53] ushort value=83 TestCastFloat8 (EURUSD,H1) 11 input value =12.000000 Hex float8 = [54] ushort value=84 TestCastFloat8 (EURUSD,H1) 12 input value =13.000000 Hex float8 = [55] ushort value=85 TestCastFloat8 (EURUSD,H1) 13 input value =14.000000 Hex float8 = [56] ushort value=86 TestCastFloat8 (EURUSD,H1) 14 input value =15.000000 Hex float8 = [57] ushort value=87 TestCastFloat8 (EURUSD,H1) ONNX input array: [56,64,68,72,74,76,78,80,81,82,83,84,85,86,87] TestCastFloat8 (EURUSD,H1) ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0] TestCastFloat8 (EURUSD,H1) 0 output float 1.000000 = [00,00,80,3f] difference=0.000000 TestCastFloat8 (EURUSD,H1) 1 output float 2.000000 = [00,00,00,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 2 output float 3.000000 = [00,00,40,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 3 output float 4.000000 = [00,00,80,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 4 output float 5.000000 = [00,00,a0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 5 output float 6.000000 = [00,00,c0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 6 output float 7.000000 = [00,00,e0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 7 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 8 output float 9.000000 = [00,00,10,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 9 output float 10.000000 = [00,00,20,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 10 output float 11.000000 = [00,00,30,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 11 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 12 output float 13.000000 = [00,00,50,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 13 output float 14.000000 = [00,00,60,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 14 output float 15.000000 = [00,00,70,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) RunCastFloat8ToFloat(FLOAT_FP8_E4M3FN): sum_error=0.000000 TestCastFloat8 (EURUSD,H1) TestCastFloat8 (EURUSD,H1) TEST: RunCastFloat8ToFloat(FLOAT_FP8_E4M3FNUZ) TestCastFloat8 (EURUSD,H1) 0 input value =1.000000 Hex float8 = [40] ushort value=64 TestCastFloat8 (EURUSD,H1) 1 input value =2.000000 Hex float8 = [48] ushort value=72 TestCastFloat8 (EURUSD,H1) 2 input value =3.000000 Hex float8 = [4c] ushort value=76 TestCastFloat8 (EURUSD,H1) 3 input value =4.000000 Hex float8 = [50] ushort value=80 TestCastFloat8 (EURUSD,H1) 4 input value =5.000000 Hex float8 = [52] ushort value=82 TestCastFloat8 (EURUSD,H1) 5 input value =6.000000 Hex float8 = [54] ushort value=84 TestCastFloat8 (EURUSD,H1) 6 input value =7.000000 Hex float8 = [56] ushort value=86 TestCastFloat8 (EURUSD,H1) 7 input value =8.000000 Hex float8 = [58] ushort value=88 TestCastFloat8 (EURUSD,H1) 8 input value =9.000000 Hex float8 = [59] ushort value=89 TestCastFloat8 (EURUSD,H1) 9 input value =10.000000 Hex float8 = [5a] ushort value=90 TestCastFloat8 (EURUSD,H1) 10 input value =11.000000 Hex float8 = [5b] ushort value=91 TestCastFloat8 (EURUSD,H1) 11 input value =12.000000 Hex float8 = [5c] ushort value=92 TestCastFloat8 (EURUSD,H1) 12 input value =13.000000 Hex float8 = [5d] ushort value=93 TestCastFloat8 (EURUSD,H1) 13 input value =14.000000 Hex float8 = [5e] ushort value=94 TestCastFloat8 (EURUSD,H1) 14 input value =15.000000 Hex float8 = [5f] ushort value=95 TestCastFloat8 (EURUSD,H1) ONNX input array: [64,72,76,80,82,84,86,88,89,90,91,92,93,94,95] TestCastFloat8 (EURUSD,H1) ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0] TestCastFloat8 (EURUSD,H1) 0 output float 1.000000 = [00,00,80,3f] difference=0.000000 TestCastFloat8 (EURUSD,H1) 1 output float 2.000000 = [00,00,00,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 2 output float 3.000000 = [00,00,40,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 3 output float 4.000000 = [00,00,80,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 4 output float 5.000000 = [00,00,a0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 5 output float 6.000000 = [00,00,c0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 6 output float 7.000000 = [00,00,e0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 7 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 8 output float 9.000000 = [00,00,10,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 9 output float 10.000000 = [00,00,20,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 10 output float 11.000000 = [00,00,30,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 11 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 12 output float 13.000000 = [00,00,50,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 13 output float 14.000000 = [00,00,60,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 14 output float 15.000000 = [00,00,70,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) RunCastFloat8ToFloat(FLOAT_FP8_E4M3FNUZ): sum_error=0.000000 TestCastFloat8 (EURUSD,H1) TestCastFloat8 (EURUSD,H1) TEST: RunCastFloat8ToFloat(FLOAT_FP8_E5M2FN) TestCastFloat8 (EURUSD,H1) 0 input value =1.000000 Hex float8 = [3c] ushort value=60 TestCastFloat8 (EURUSD,H1) 1 input value =2.000000 Hex float8 = [40] ushort value=64 TestCastFloat8 (EURUSD,H1) 2 input value =3.000000 Hex float8 = [42] ushort value=66 TestCastFloat8 (EURUSD,H1) 3 input value =4.000000 Hex float8 = [44] ushort value=68 TestCastFloat8 (EURUSD,H1) 4 input value =5.000000 Hex float8 = [45] ushort value=69 TestCastFloat8 (EURUSD,H1) 5 input value =6.000000 Hex float8 = [46] ushort value=70 TestCastFloat8 (EURUSD,H1) 6 input value =7.000000 Hex float8 = [47] ushort value=71 TestCastFloat8 (EURUSD,H1) 7 input value =8.000000 Hex float8 = [48] ushort value=72 TestCastFloat8 (EURUSD,H1) 8 input value =8.000000 Hex float8 = [48] ushort value=72 TestCastFloat8 (EURUSD,H1) 9 input value =10.000000 Hex float8 = [49] ushort value=73 TestCastFloat8 (EURUSD,H1) 10 input value =12.000000 Hex float8 = [4a] ushort value=74 TestCastFloat8 (EURUSD,H1) 11 input value =12.000000 Hex float8 = [4a] ushort value=74 TestCastFloat8 (EURUSD,H1) 12 input value =12.000000 Hex float8 = [4a] ushort value=74 TestCastFloat8 (EURUSD,H1) 13 input value =14.000000 Hex float8 = [4b] ushort value=75 TestCastFloat8 (EURUSD,H1) 14 input value =16.000000 Hex float8 = [4c] ushort value=76 TestCastFloat8 (EURUSD,H1) ONNX input array: [60,64,66,68,69,70,71,72,72,73,74,74,74,75,76] TestCastFloat8 (EURUSD,H1) ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,8.0,10.0,12.0,12.0,12.0,14.0,16.0] TestCastFloat8 (EURUSD,H1) 0 output float 1.000000 = [00,00,80,3f] difference=0.000000 TestCastFloat8 (EURUSD,H1) 1 output float 2.000000 = [00,00,00,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 2 output float 3.000000 = [00,00,40,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 3 output float 4.000000 = [00,00,80,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 4 output float 5.000000 = [00,00,a0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 5 output float 6.000000 = [00,00,c0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 6 output float 7.000000 = [00,00,e0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 7 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 8 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 9 output float 10.000000 = [00,00,20,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 10 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 11 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 12 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 13 output float 14.000000 = [00,00,60,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 14 output float 16.000000 = [00,00,80,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) RunCastFloat8ToFloat(FLOAT_FP8_E5M2FN): sum_error=0.000000 TestCastFloat8 (EURUSD,H1) TestCastFloat8 (EURUSD,H1) TEST: RunCastFloat8ToFloat(FLOAT_FP8_E5M2FNUZ) TestCastFloat8 (EURUSD,H1) 0 input value =1.000000 Hex float8 = [40] ushort value=64 TestCastFloat8 (EURUSD,H1) 1 input value =2.000000 Hex float8 = [44] ushort value=68 TestCastFloat8 (EURUSD,H1) 2 input value =3.000000 Hex float8 = [46] ushort value=70 TestCastFloat8 (EURUSD,H1) 3 input value =4.000000 Hex float8 = [48] ushort value=72 TestCastFloat8 (EURUSD,H1) 4 input value =5.000000 Hex float8 = [49] ushort value=73 TestCastFloat8 (EURUSD,H1) 5 input value =6.000000 Hex float8 = [4a] ushort value=74 TestCastFloat8 (EURUSD,H1) 6 input value =7.000000 Hex float8 = [4b] ushort value=75 TestCastFloat8 (EURUSD,H1) 7 input value =8.000000 Hex float8 = [4c] ushort value=76 TestCastFloat8 (EURUSD,H1) 8 input value =8.000000 Hex float8 = [4c] ushort value=76 TestCastFloat8 (EURUSD,H1) 9 input value =10.000000 Hex float8 = [4d] ushort value=77 TestCastFloat8 (EURUSD,H1) 10 input value =12.000000 Hex float8 = [4e] ushort value=78 TestCastFloat8 (EURUSD,H1) 11 input value =12.000000 Hex float8 = [4e] ushort value=78 TestCastFloat8 (EURUSD,H1) 12 input value =12.000000 Hex float8 = [4e] ushort value=78 TestCastFloat8 (EURUSD,H1) 13 input value =14.000000 Hex float8 = [4f] ushort value=79 TestCastFloat8 (EURUSD,H1) 14 input value =16.000000 Hex float8 = [50] ushort value=80 TestCastFloat8 (EURUSD,H1) ONNX input array: [64,68,70,72,73,74,75,76,76,77,78,78,78,79,80] TestCastFloat8 (EURUSD,H1) ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,8.0,10.0,12.0,12.0,12.0,14.0,16.0] TestCastFloat8 (EURUSD,H1) 0 output float 1.000000 = [00,00,80,3f] difference=0.000000 TestCastFloat8 (EURUSD,H1) 1 output float 2.000000 = [00,00,00,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 2 output float 3.000000 = [00,00,40,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 3 output float 4.000000 = [00,00,80,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 4 output float 5.000000 = [00,00,a0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 5 output float 6.000000 = [00,00,c0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 6 output float 7.000000 = [00,00,e0,40] difference=0.000000 TestCastFloat8 (EURUSD,H1) 7 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 8 output float 8.000000 = [00,00,00,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 9 output float 10.000000 = [00,00,20,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 10 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 11 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 12 output float 12.000000 = [00,00,40,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 13 output float 14.000000 = [00,00,60,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) 14 output float 16.000000 = [00,00,80,41] difference=0.000000 TestCastFloat8 (EURUSD,H1) RunCastFloat8ToFloat(FLOAT_FP8_E5M2FNUZ): sum_error=0.000000 TestCastFloat8 (EURUSD,H1)
2. 画像の高解析度化でのONNXの使用
このセクションでは、画像の解像度を向上させるためにSRGANモデルを使用する例を検討します。
ESRGAN (Enhanced Super-Resolution Generative Adversarial Networks)は、画像の高解析度化タスクに対処するために設計された強力なニューラルネットワークアーキテクチャです。ESRGANは、解像度を高めて画質を向上させるために開発されました。これは、低解像度画像とそれに対応する高画質画像の大規模なデータセットに対してディープニューラルネットワークを訓練することで実現されます。ESRGANはGenerative Adversarial Networks (GAN)のアーキテクチャを採用しており、生成器と識別器の2つの主要コンポーネントから構成されています。生成器は高解像度の画像を生成する役割を担い、識別器は生成された画像と実際の画像を区別するように訓練されます。
ESRGANアーキテクチャの中核をなすのは残差ブロックであり、さまざまな抽象度で画像の重要な特徴を抽出保存するのに役立ちます。これにより、ネットワークは高画質画像のディテールやテクスチャーを効率的に復元することができます。
ESRGANは、超解像タスクの解決において高い品質と普遍性を達成するために、大規模な訓練データセットを必要とします。これにより、ネットワークは画像のさまざまなスタイルや特性を学習し、さまざまな型の入力データに適応できるようになります。ESRGANは、写真撮影、医療診断、映画ビデオ制作、グラフィックデザインなど、さまざまな分野で画質の向上に利用できます。その柔軟性と効率性により、画像超解像の分野では主要な手法の1つとなっています。
ESRGANは、画像処理と人工知能の分野で大きな進歩を遂げ、画像の作成と向上の新たな可能性を切り開きました。
2.1. float32によるONNXモデルの実行
サンプルを実行するには、https://github.com/amannm/super-resolution-service/blob/main/models/esrgan.onnxでファイルをダウンロードし、\MQL5\Scripts\modelsフォルダにコピーする必要があります。
ESRGAN.onnxモデルには約1200のONNX操作が含まれています。その初期操作を図12に示します。
図12:MetaEditorにおけるESRGAN.onnxモデルの説明
図13:NetronにおけるESRGAN.ONNXモデル
まずesrgan.onnxモデルを読み込み、次にBMP形式のオリジナル画像を選択して読み込みます。その後、画像は別々のRGBチャンネルに変換され、入力としてモデルに供給されます。このモデルは、画像を4倍にアップスケールする処理を実行し、その結果、アップスケールされた画像は逆変換を受け、表示用に準備されます。
表示にはCanvasライブラリを使用し、モデルの実行にはONNXRuntimeライブラリを使用します。プログラムを実行すると、アップスケールされた画像は元のファイル名に「_upscaled」が付加されたファイルに保存されます。主な機能には、画像の前処理と後処理、および画像のアップスケーリングのためのモデル実行が含まれます。
//+------------------------------------------------------------------+ //| ESRGAN.mq5 | //| Copyright 2024, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2024, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" //+------------------------------------------------------------------+ //| 4x image upscaling demo using ESRGAN | //| esrgan.onnx model from | //| https://github.com/amannm/super-resolution-service/ | //+------------------------------------------------------------------+ //| Xintao Wang et al (2018) | //| ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks| //| https://arxiv.org/abs/1809.00219 | //+------------------------------------------------------------------+ #resource "models\\esrgan.onnx" as uchar ExtModel[]; #include <Canvas\Canvas.mqh> //+------------------------------------------------------------------+ //| clamp | //+------------------------------------------------------------------+ float clamp(float value, float minValue, float maxValue) { return MathMin(MathMax(value, minValue), maxValue); } //+------------------------------------------------------------------+ //| Preprocessing | //+------------------------------------------------------------------+ bool Preprocessing(float &data[],uint &image_data[],int &image_width,int &image_height) { //--- checkup if(image_height==0 || image_width==0) return(false); //--- prepare destination array with separated RGB channels for ONNX model int data_count=3*image_width*image_height; if(ArrayResize(data,data_count)!=data_count) { Print("ArrayResize failed"); return(false); } //--- converting for(int y=0; y<image_height; y++) for(int x=0; x<image_width; x++) { //--- load source RGB int offset=y*image_width+x; uint clr =image_data[offset]; uchar r =GETRGBR(clr); uchar g =GETRGBG(clr); uchar b =GETRGBB(clr); //--- store RGB components as separated channels int offset_ch1=0*image_width*image_height+offset; int offset_ch2=1*image_width*image_height+offset; int offset_ch3=2*image_width*image_height+offset; data[offset_ch1]=r/255.0f; data[offset_ch2]=g/255.0f; data[offset_ch3]=b/255.0f; } //--- return(true); } //+------------------------------------------------------------------+ //| PostProcessing | //+------------------------------------------------------------------+ bool PostProcessing(const float &data[], uint &image_data[], const int &image_width, const int &image_height) { //--- checks if(image_height == 0 || image_width == 0) return(false); int data_count=image_width*image_height; if(ArraySize(data)!=3*data_count) return(false); if(ArrayResize(image_data,data_count)!=data_count) return(false); //--- for(int y=0; y<image_height; y++) for(int x=0; x<image_width; x++) { int offset =y*image_width+x; int offset_ch1=0*image_width*image_height+offset; int offset_ch2=1*image_width*image_height+offset; int offset_ch3=2*image_width*image_height+offset; //--- rescale to [0..255] float r=clamp(data[offset_ch1]*255,0,255); float g=clamp(data[offset_ch2]*255,0,255); float b=clamp(data[offset_ch3]*255,0,255); //--- set color image_data image_data[offset]=XRGB(uchar(r),uchar(g),uchar(b)); } //--- return(true); } //+------------------------------------------------------------------+ //| ShowImage | //+------------------------------------------------------------------+ bool ShowImage(CCanvas &canvas,const string name,const int x0,const int y0,const int image_width,const int image_height, const uint &image_data[]) { if(ArraySize(image_data)==0 || name=="") return(false); //--- prepare canvas canvas.CreateBitmapLabel(name,x0,y0,image_width,image_height,COLOR_FORMAT_XRGB_NOALPHA); //--- copy image to canvas for(int y=0; y<image_height; y++) for(int x=0; x<image_width; x++) canvas.PixelSet(x,y,image_data[y*image_width+x]); //--- ready to draw canvas.Update(true); return(true); } //+------------------------------------------------------------------+ //| Script program start function | //+------------------------------------------------------------------+ int OnStart(void) { //--- select BMP from <data folder>\MQL5\Files string image_path[1]; if(FileSelectDialog("Select BMP image",NULL,"Bitmap files (*.bmp)|*.bmp",FSD_FILE_MUST_EXIST,image_path,"lenna-original4.bmp")!=1) { Print("file not selected"); return(-1); } //--- load BMP into array uint image_data[]; int image_width; int image_height; if(!CCanvas::LoadBitmap(image_path[0],image_data,image_width,image_height)) { PrintFormat("CCanvas::LoadBitmap failed with error %d",GetLastError()); return(-1); } //--- convert RGB image to separated RGB channels float input_data[]; Preprocessing(input_data,image_data,image_width,image_height); PrintFormat("input array size=%d",ArraySize(input_data)); //--- load model long model_handle=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT); if(model_handle==INVALID_HANDLE) { PrintFormat("OnnxCreate error %d",GetLastError()); return(-1); } PrintFormat("model loaded successfully"); PrintFormat("original: width=%d, height=%d Size=%d",image_width,image_height,ArraySize(image_data)); //--- set input shape ulong input_shape[]={1,3,image_height,image_width}; if(!OnnxSetInputShape(model_handle,0,input_shape)) { PrintFormat("error in OnnxSetInputShape. error code=%d",GetLastError()); OnnxRelease(model_handle); return(-1); } //--- upscaled image size int new_image_width =4*image_width; int new_image_height=4*image_height; ulong output_shape[]= {1,3,new_image_height,new_image_width}; if(!OnnxSetOutputShape(model_handle,0,output_shape)) { PrintFormat("error in OnnxSetOutputShape. error code=%d",GetLastError()); OnnxRelease(model_handle); return(-1); } //--- run the model float output_data[]; int new_data_count=3*new_image_width*new_image_height; if(ArrayResize(output_data,new_data_count)!=new_data_count) { OnnxRelease(model_handle); return(-1); } if(!OnnxRun(model_handle,ONNX_DEBUG_LOGS,input_data,output_data)) { PrintFormat("error in OnnxRun. error code=%d",GetLastError()); OnnxRelease(model_handle); return(-1); } Print("model successfully executed, output data size ",ArraySize(output_data)); OnnxRelease(model_handle); //--- postprocessing uint new_image[]; PostProcessing(output_data,new_image,new_image_width,new_image_height); //--- show images CCanvas canvas_original,canvas_scaled; ShowImage(canvas_original,"original_image",new_image_width,0,image_width,image_height,image_data); ShowImage(canvas_scaled,"upscaled_image",0,0,new_image_width,new_image_height,new_image); //--- save upscaled image StringReplace(image_path[0],".bmp","_upscaled.bmp"); Print(ResourceSave(canvas_scaled.ResourceName(),image_path[0])); //--- while(!IsStopped()) Sleep(100); return(0); } //+------------------------------------------------------------------+
以下が出力です。
図14:ESRGAN.onnxモデルの実行結果(160x200->640x800)
この例では、160x200の画像をESRGAN.onnxモデルを使用して4倍(640x800)に拡大しました。
2.2.float16を使用したONNXモデルの実行例
モデルをfloat16に変換するには、「Create Float16 and Mixed Precision Models」で説明されている方法を使用します。
# Copyright 2024, MetaQuotes Ltd. # https://www.mql5.com import onnx from onnxconverter_common import float16 from sys import argv # Define the path for saving the model data_path = argv[0] last_index = data_path.rfind("\\") + 1 data_path = data_path[0:last_index] # convert the model to float16 model_path = data_path+'\\models\\esrgan.onnx' modelfp16_path = data_path+'\\models\\esrgan_float16.onnx' model = onnx.load(model_path) model_fp16 = float16.convert_float_to_float16(model) onnx.save(model_fp16, modelfp16_path)
変換後、ファイルサイズは半分になりました(64MBから32MB)。
コードの変更は最小限です。
//+------------------------------------------------------------------+ //| ESRGAN_float16.mq5 | //| Copyright 2024, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2024, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" //+------------------------------------------------------------------+ //| 4x image upscaling demo using ESRGAN | //| esrgan.onnx model from | //| https://github.com/amannm/super-resolution-service/ | //+------------------------------------------------------------------+ //| Xintao Wang et al (2018) | //| ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks| //| https://arxiv.org/abs/1809.00219 | //+------------------------------------------------------------------+ #resource "models\\esrgan_float16.onnx" as uchar ExtModel[]; #include <Canvas\Canvas.mqh> //+------------------------------------------------------------------+ //| clamp | //+------------------------------------------------------------------+ float clamp(float value, float minValue, float maxValue) { return MathMin(MathMax(value, minValue), maxValue); } //+------------------------------------------------------------------+ //| Preprocessing | //+------------------------------------------------------------------+ bool Preprocessing(float &data[],uint &image_data[],int &image_width,int &image_height) { //--- checkup if(image_height==0 || image_width==0) return(false); //--- prepare destination array with separated RGB channels for ONNX model int data_count=3*image_width*image_height; if(ArrayResize(data,data_count)!=data_count) { Print("ArrayResize failed"); return(false); } //--- converting for(int y=0; y<image_height; y++) for(int x=0; x<image_width; x++) { //--- load source RGB int offset=y*image_width+x; uint clr =image_data[offset]; uchar r =GETRGBR(clr); uchar g =GETRGBG(clr); uchar b =GETRGBB(clr); //--- store RGB components as separated channels int offset_ch1=0*image_width*image_height+offset; int offset_ch2=1*image_width*image_height+offset; int offset_ch3=2*image_width*image_height+offset; data[offset_ch1]=r/255.0f; data[offset_ch2]=g/255.0f; data[offset_ch3]=b/255.0f; } //--- return(true); } //+------------------------------------------------------------------+ //| PostProcessing | //+------------------------------------------------------------------+ bool PostProcessing(const float &data[], uint &image_data[], const int &image_width, const int &image_height) { //--- checks if(image_height == 0 || image_width == 0) return(false); int data_count=image_width*image_height; if(ArraySize(data)!=3*data_count) return(false); if(ArrayResize(image_data,data_count)!=data_count) return(false); //--- for(int y=0; y<image_height; y++) for(int x=0; x<image_width; x++) { int offset =y*image_width+x; int offset_ch1=0*image_width*image_height+offset; int offset_ch2=1*image_width*image_height+offset; int offset_ch3=2*image_width*image_height+offset; //--- rescale to [0..255] float r=clamp(data[offset_ch1]*255,0,255); float g=clamp(data[offset_ch2]*255,0,255); float b=clamp(data[offset_ch3]*255,0,255); //--- set color image_data image_data[offset]=XRGB(uchar(r),uchar(g),uchar(b)); } //--- return(true); } //+------------------------------------------------------------------+ //| ShowImage | //+------------------------------------------------------------------+ bool ShowImage(CCanvas &canvas,const string name,const int x0,const int y0,const int image_width,const int image_height, const uint &image_data[]) { if(ArraySize(image_data)==0 || name=="") return(false); //--- prepare canvas canvas.CreateBitmapLabel(name,x0,y0,image_width,image_height,COLOR_FORMAT_XRGB_NOALPHA); //--- copy image to canvas for(int y=0; y<image_height; y++) for(int x=0; x<image_width; x++) canvas.PixelSet(x,y,image_data[y*image_width+x]); //--- ready to draw canvas.Update(true); return(true); } //+------------------------------------------------------------------+ //| Script program start function | //+------------------------------------------------------------------+ int OnStart(void) { //--- select BMP from <data folder>\MQL5\Files string image_path[1]; if(FileSelectDialog("Select BMP image",NULL,"Bitmap files (*.bmp)|*.bmp",FSD_FILE_MUST_EXIST,image_path,"lenna.bmp")!=1) { Print("file not selected"); return(-1); } //--- load BMP into array uint image_data[]; int image_width; int image_height; if(!CCanvas::LoadBitmap(image_path[0],image_data,image_width,image_height)) { PrintFormat("CCanvas::LoadBitmap failed with error %d",GetLastError()); return(-1); } //--- convert RGB image to separated RGB channels float input_data[]; Preprocessing(input_data,image_data,image_width,image_height); PrintFormat("input array size=%d",ArraySize(input_data)); ushort input_data_float16[]; if(!ArrayToFP16(input_data_float16,input_data,FLOAT_FP16)) { Print("error in ArrayToFP16. error code=",GetLastError()); return(false); } //--- load model long model_handle=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT); if(model_handle==INVALID_HANDLE) { PrintFormat("OnnxCreate error %d",GetLastError()); return(-1); } PrintFormat("model loaded successfully"); PrintFormat("original: width=%d, height=%d Size=%d",image_width,image_height,ArraySize(image_data)); //--- set input shape ulong input_shape[]={1,3,image_height,image_width}; if(!OnnxSetInputShape(model_handle,0,input_shape)) { PrintFormat("error in OnnxSetInputShape. error code=%d",GetLastError()); OnnxRelease(model_handle); return(-1); } //--- upscaled image size int new_image_width =4*image_width; int new_image_height=4*image_height; ulong output_shape[]= {1,3,new_image_height,new_image_width}; if(!OnnxSetOutputShape(model_handle,0,output_shape)) { PrintFormat("error in OnnxSetOutputShape. error code=%d",GetLastError()); OnnxRelease(model_handle); return(-1); } //--- run the model float output_data[]; ushort output_data_float16[]; int new_data_count=3*new_image_width*new_image_height; if(ArrayResize(output_data_float16,new_data_count)!=new_data_count) { OnnxRelease(model_handle); return(-1); } if(!OnnxRun(model_handle,ONNX_NO_CONVERSION,input_data_float16,output_data_float16)) { PrintFormat("error in OnnxRun. error code=%d",GetLastError()); OnnxRelease(model_handle); return(-1); } Print("model successfully executed, output data size ",ArraySize(output_data)); OnnxRelease(model_handle); if(!ArrayFromFP16(output_data,output_data_float16,FLOAT_FP16)) { Print("error in ArrayFromFP16. error code=",GetLastError()); return(false); } //--- postprocessing uint new_image[]; PostProcessing(output_data,new_image,new_image_width,new_image_height); //--- show images CCanvas canvas_original,canvas_scaled; ShowImage(canvas_original,"original_image",new_image_width,0,image_width,image_height,image_data); ShowImage(canvas_scaled,"upscaled_image",0,0,new_image_width,new_image_height,new_image); //--- save upscaled image StringReplace(image_path[0],".bmp","_upscaled.bmp"); Print(ResourceSave(canvas_scaled.ResourceName(),image_path[0])); //--- while(!IsStopped()) Sleep(100); return(0); } //+------------------------------------------------------------------+
float16形式に変換されたモデルを実行するために必要なコードの変更は、色で強調表示されています。
以下が出力です。
図15:ESRGAN_float16.onnxモデルの実行結果(160x200->640x800)
このように、float32の代わりにfloat16を使うことで、ONNXモデルファイルのサイズを半分に減らすことができます(64MBから32MB)。
数値がfloat16のモデルを実行しても画質は変わらず、視覚的に違いを見つけるのは難しいです。
図16:floatとfloat16のESRGANモデル演算結果の比較
コードの変更は最小限であり、入出力データの変換に注意を払うだけでよいです。
この場合、float16に変換しても、モデルのパフォーマンスに大きな変化はありませんでした。ただし、金融データを分析する際には、可能な限り精度の高い計算に努めることが不可欠です。
終わりに
浮動小数点数の新しいデータ型を使用することで、品質を大きく損なうことなくONNXモデルのサイズを小さくすることができます。
ArrayToFP16/ArrayFromFP16およびArrayToFP8/ArrayFromFP8の変換関数を使用すると、データの前処理と後処理が大幅に簡素化されます。
変換されたONNXモデルで動作するために必要なコード変更は最小限です。
MetaQuotes Ltdによってロシア語から翻訳されました。
元の記事: https://www.mql5.com/ru/articles/14330




- 無料取引アプリ
- 8千を超えるシグナルをコピー
- 金融ニュースで金融マーケットを探索