GeMM
Phương thức GeMM (Nhân Ma trận Tổng quát) thực hiện phép nhân tổng quát của hai ma trận. Phép toán được định nghĩa là C ← α A B + β C
, trong đó ma trận A và B có thể được chuyển vị tùy chọn. Với phép nhân ma trận thông thường AB (MatMul), hệ số vô hướng alpha
được giả định bằng 1 và beta
bằng 0.
Sự khác biệt chính giữa GeMM và MatMul về hiệu suất là MatMul luôn tạo ra một đối tượng ma trận/vector mới, trong khi GeMM hoạt động với một đối tượng ma trận hiện có và không tạo lại nó. Do đó, khi sử dụng GeMM và cấp phát trước bộ nhớ cho ma trận tương ứng, thì khi làm việc với cùng kích thước ma trận, sẽ không có việc cấp phát lại bộ nhớ. Đây có thể là một lợi thế quan trọng của GeMM trong tính toán hàng loạt, ví dụ như khi chạy tối ưu hóa trong bộ kiểm tra chiến lược hoặc khi huấn luyện mạng nơ-ron.
Tương tự như MatMul, GeMM cũng có 4 biến thể. Tuy nhiên, ngữ nghĩa của biến thể thứ tư đã được sửa đổi để cho phép nhân một vector dọc với một vector ngang.
Trong một đối tượng ma trận/vector hiện có, không cần phải cấp phát trước bộ nhớ. Bộ nhớ sẽ được cấp phát và điền bằng số không tại lần gọi GeMM đầu tiên.
Nhân ma trận với ma trận: matrix C[M][N] = α * (matrix A[M][K] * matrix B[K][N]) + β * matrix C[M][N]
bool matrix::GeMM(
const matrix &A, // ma trận đầu tiên
const matrix &B, // ma trận thứ hai
double alpha, // hệ số alpha cho tích AB
double beta, // hệ số beta cho ma trận C
uint flags // tổ hợp các giá trị ENUM_GEMM (bitwise OR), xác định liệu ma trận A, B và C có được chuyển vị hay không
);
2
3
4
5
6
7
Nhân vector với ma trận: vector C[N] = α * (vector A[K] * matrix B[K][N]) + β * vector C[N]
bool vector::GeMM(
const vector &A, // vector ngang
const matrix &B, // ma trận
double alpha, // hệ số alpha cho tích AB
double beta, // hệ số beta cho vector C
uint flags // giá trị liệt kê ENUM_GEMM xác định liệu ma trận A có được chuyển vị hay không
);
2
3
4
5
6
7
Nhân ma trận với vector: vector C[M] = α * (matrix A[M][K] * vector B[K]) + β * vector C[M]
bool vector::GeMM(
const matrix &A, // ma trận
const vector &B, // vector dọc
double alpha, // hệ số alpha cho tích AB
double beta, // hệ số beta cho vector C
uint flags // giá trị liệt kê ENUM_GEMM xác định liệu ma trận B có được chuyển vị hay không
);
2
3
4
5
6
7
Nhân vector với vector: matrix C[M][N] = α * (vector A[M] * vector B[N]) + β * matrix C[M][N]
. Biến thể này trả về một ma trận, không giống MatMul nơi nó trả về một giá trị vô hướng.
bool matrix::GeMM(
const vector &A, // vector đầu tiên
const vector &B, // vector thứ hai
double alpha, // hệ số alpha cho tích AB
double beta, // hệ số beta cho ma trận C
uint flags // giá trị liệt kê ENUM_GEMM xác định liệu ma trận C có được chuyển vị hay không
);
2
3
4
5
6
7
Tham số
A
[in] Ma trận hoặc vector.
B
[in] Ma trận hoặc vector.
alpha
[in] Hệ số alpha cho tích AB.
beta
[in] Hệ số beta cho ma trận C kết quả.
flags
[in] Giá trị liệt kê ENUM_GEMM xác định liệu ma trận A, B và C có được chuyển vị hay không.
Giá trị trả về
Trả về true
nếu thành công hoặc false
nếu không.
ENUM_GEMM
Liệt kê các cờ cho phương thức GeMM.
ID | Mô tả |
---|---|
TRANSP_A | Sử dụng ma trận A chuyển vị |
TRANSP_B | Sử dụng ma trận B chuyển vị |
TRANSP_C | Sử dụng ma trận C chuyển vị |
Ghi chú
Ma trận và vector của các kiểu float, double và complex có thể được sử dụng làm tham số A và B. Các biến thể mẫu của phương thức GeMM bao gồm:
bool matrix<T>::GeMM(const matrix<T> &A, const matrix<T> &B, T alpha, T beta, ulong flags);
bool matrix<T>::GeMM(const vector<T> &A, const vector<T> &B, T alpha, T beta, ulong flags);
bool vector<T>::GeMM(const vector<T> &A, const matrix<T> &B, T alpha, T beta, ulong flags);
bool vector<T>::GeMM(const matrix<T> &A, const vector<T> &B, T alpha, T beta, ulong flags);
2
3
4
5
Về cơ bản, hàm nhân ma trận tổng quát được mô tả như sau:
C[m,n] = α * Sum(A[m,k] * B[k,n]) + β * C[m,n]
với các kích thước sau: ma trận A là M x K, ma trận B là K x N và ma trận C là M x N.
Do đó, các ma trận cần tương thích để nhân, tức là số cột của ma trận đầu tiên phải bằng số hàng của ma trận thứ hai. Phép nhân ma trận không có tính giao hoán: kết quả của việc nhân ma trận đầu tiên với ma trận thứ hai không bằng kết quả của việc nhân ma trận thứ hai với ma trận đầu tiên trong trường hợp tổng quát.
Ví dụ:
void OnStart()
{
vector vector_a= {1, 2, 3, 4, 5};
vector vector_b= {4, 3, 2, 1};
matrix matrix_c;
//--- tính GeMM cho hai vector
matrix_c.GeMM(vector_a, vector_b, 1, 0);
Print("matrix_c:\n ", matrix_c, "\n");
/*
matrix_c:
[[4,3,2,1]
[8,6,4,2]
[12,9,6,3]
[16,12,8,4]
[20,15,10,5]]
*/
//--- tạo ma trận từ vector
matrix matrix_a(5, 1);
matrix matrix_b(1, 4);
matrix_a.Col(vector_a, 0);
matrix_b.Row(vector_b, 0);
Print("matrix_a:\n ", matrix_a);
Print("matrix_b:\n ", matrix_b);
/*
matrix_a:
[[1]
[2]
[3]
[4]
[5]]
matrix_b:
[[4,3,2,1]]
*/
//-- tính GeMM cho hai ma trận và nhận cùng kết quả
matrix_c.GeMM(matrix_a, matrix_b, 1, 0);
Print("matrix_c:\n ", matrix_c);
/*
matrix_c:
[[4,3,2,1]
[8,6,4,2]
[12,9,6,3]
[16,12,8,4]
[20,15,10,5]]
*/
}
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
Xem thêm