LossGradient

손실 함수 기울기의 벡터 또는 행렬을 계산합니다.

vector vector::LossGradient(
  const vector&       vect_true,     // 참인 값의 벡터
  ENUM_LOSS_FUNCTION  loss,          // 손실 함수 종류
   ...                               // 추가 매개변수
   );
 
 
matrix matrix::LossGradient(
  const matrix&       matrix_true,   // 참인 값의 행렬
  ENUM_LOSS_FUNCTION  loss,          // 손실 함수
   );
 
 
matrix matrix::LossGradient(
  const matrix&       matrix_true,   // 참인 값의 행렬
  ENUM_LOSS_FUNCTION  loss,          // 손실 함수
  ENUM_MATRIX_AXIS    axis,          // axis
   ...                               // 추가 매개변수
   );

매개 변수

vect_true/matrix_true

 [in] 참인 값의 행렬 혹은 벡터.

loss

【in】ENUM_LOSS_FUNCTION 열거로부터의 손실 함수.

axis

[in]ENUM_MATRIX_AXIS 열거형 값(AXIS_HORZ — 가로축, AXIS_VERT — 세로축).

...

【in】 추가 매개변수 'delta'는 Hubert 손실 함수(LOSS_HUBER)에서만 사용할 수 있습니다.

Return Value

손실 함수 기울기 값의 벡터 또는 행렬. 그래디언트는 주어진 지점에서 손실 함수의 dx(x는 예측값)에 대한 편도함수입니다.

참조

기울기는 신경망에서 모델을 훈련할 때 역전파 중에 가중치 행렬의 가중치를 조정하는 데 사용됩니다.

신경망은 손실 함수가 사용되는 학습 샘플의 오류를 최소화하는 알고리즘을 찾는 것을 목표로 합니다.

문제에 따라 각기 다른 손실 함수가 사용됩니다. 예를 들어 평균 제곱 오차(Mean Squared Error)(MSE)는 회귀 문제에 사용되며 Binary Cross-Entropy(BCE)는 이진 분류 목적으로 사용됩니다.

손실 함수 기울기 계산의 예

   matrixf y_true={{ 1234 },
                   { 5678 },
                   { 9,10,11,12 }};
   matrixf y_pred={{ 1234 },
                   {11,1098 },
                   { 567,12 }};
   matrixf loss_gradient =y_pred.LossGradient(y_true,LOSS_MAE);
   matrixf loss_gradienth=y_pred.LossGradient(y_true,LOSS_MAE,AXIS_HORZ);
   matrixf loss_gradientv=y_pred.LossGradient(y_true,LOSS_MAE,AXIS_VERT);
   Print("loss gradients\n",loss_gradient);
   Print("loss gradients on horizontal axis\n",loss_gradienth);
   Print("loss gradients on vertical axis\n",loss_gradientv);
 
/* Result
   loss gradients
   [[0,0,0,0]
    [0.083333336,0.083333336,0.083333336,0]
    [-0.083333336,-0.083333336,-0.083333336,0]]
   loss gradients on horizontal axis
   [[0,0,0,0]
    [0.33333334,0.33333334,0.33333334,0]
    [-0.33333334,-0.33333334,-0.33333334,0]]
   loss gradients on vertical axis
   [[0,0,0,0]
    [0.25,0.25,0.25,0]
    [-0.25,-0.25,-0.25,0]]
*/