PrecisionRecall
Tính toán các giá trị để xây dựng đường cong precision-recall. Tương tự như ClassificationScore
, phương thức này được áp dụng cho vector chứa các giá trị thực.
bool vector::PrecisionRecall(
const matrix& pred_scores, // ma trận chứa phân phối xác suất cho mỗi lớp
const ENUM_ENUM_AVERAGE_MODE mode, // chế độ trung bình
matrix& precision, // giá trị precision được tính toán cho mỗi giá trị ngưỡng
matrix& recall, // giá trị recall được tính toán cho mỗi giá trị ngưỡng
matrix& thresholds, // các giá trị ngưỡng được sắp xếp theo thứ tự giảm dần
);
2
3
4
5
6
7
Parameters
pred_scores
[in] Một ma trận chứa tập hợp các vector ngang với xác suất cho mỗi lớp. Số hàng của ma trận phải tương ứng với kích thước của vector chứa các giá trị thực.
mode
[in] Chế độ trung bình từ bảng liệt kê ENUM_AVERAGE_MODE
. Chỉ sử dụng AVERAGE_NONE
, AVERAGE_BINARY
và AVERAGE_MICRO
.
precision
[out] Một ma trận chứa các giá trị đường cong precision được tính toán. Nếu không áp dụng trung bình (AVERAGE_NONE
), số hàng trong ma trận tương ứng với số lớp của mô hình. Số cột tương ứng với kích thước của vector chứa các giá trị thực (hoặc số hàng trong ma trận phân phối xác suất pred_scores
). Trong trường hợp trung bình vi mô, số hàng trong ma trận tương ứng với tổng số giá trị ngưỡng, không bao gồm các giá trị trùng lặp.
recall
[out] Một ma trận chứa các giá trị đường cong recall được tính toán.
thresholds
[out] Ma trận ngưỡng thu được bằng cách sắp xếp ma trận xác suất.
Note
Xem ghi chú cho phương thức ClassificationScore
.
Ví dụ
Một ví dụ về việc thu thập thống kê từ mô hình mnist.onnx
(độ chính xác 99%).
//--- dữ liệu cho các chỉ số phân loại
vectorf y_true(images);
vectorf y_pred(images);
matrixf y_scores(images,10);
//--- đầu vào-đầu ra
matrixf image(28,28);
vectorf result(10);
//--- kiểm tra
for(int test=0; test<images; test++)
{
image=test_data[test].image;
if(!OnnxRun(model,ONNX_DEFAULT,image,result))
{
Print("OnnxRun error ",GetLastError());
break;
}
result.Activation(result,AF_SOFTMAX);
//--- thu thập dữ liệu
y_true[test]=(float)test_data[test].label;
y_pred[test]=(float)result.ArgMax();
y_scores.Row(result,test);
}
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
vectorf accuracy=y_pred.ClassificationMetric(y_true,CLASSIFICATION_ACCURACY);
PrintFormat("accuracy=%f",accuracy[0]);
accuracy=0.989000
2
3
4
Ví dụ về việc vẽ biểu đồ precision-recall, trong đó các giá trị precision được vẽ trên trục y và các giá trị recall được vẽ trên trục x. Ngoài ra, các biểu đồ precision và recall cũng được vẽ riêng biệt, với các giá trị ngưỡng được vẽ trên trục x.
if(y_true.PrecisionRecall(y_scores,AVERAGE_MICRO,mat_precision,mat_recall,mat_thres))
{
double precision[],recall[],thres[];
ArrayResize(precision,mat_thres.Cols());
ArrayResize(recall,mat_thres.Cols());
ArrayResize(thres,mat_thres.Cols());
for(uint i=0; i<thres.Size(); i++)
{
precision[i]=mat_precision[0][i];
recall[i]=mat_recall[0][i];
thres[i]=mat_thres[0][i];
}
thres[0]=thres[1]+0.001;
PlotCurve("Precision-Recall curve (micro average)","p-r","",recall,precision);
Plot2Curves("Precision-Recall (micro average)","precision","recall",thres,precision,recall);
}
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
Đường cong kết quả: