带限矩阵方程组求解


题目

by 李锦朋

描述

解 \(AX=Z\), 其中

\[A = \begin{cases} \begin{bmatrix} b_1 & c_1 \\ a_2 & \ddots & \ddots \\ & \ddots & \ddots & c_{n-1} \\ & & a_n & b_n \end{bmatrix} &, p = 3; \\ \\ \begin{bmatrix} c_1 & d_1 & e_1 \\ b_2 & \ddots & \ddots & \ddots \\ a_3 & \ddots & \ddots & \ddots & e_{n-1} \\ & \ddots & \ddots & \ddots & d_{n-1} \\ & & a_n & b_n & c_n \end{bmatrix} &, p = 5; \end{cases}\] \[Z = \left[z_1 \cdots z_m \right].\]

保留 4 位小数.

IO格式

输入:

p
A的阶数n Z的列数m
// p switch
// 3 =>
    c[1] ... c[n-1]
    b[1] ... b[n]
    a[2] ... a[n]
// 5 =>
    e[1] ... e[n-2]
    d[1] ... d[n-1]
    c[1] ... c[n]
    b[2] ... b[n]
    a[3] ... a[n]
z1[1] ... z1[n]
...
zm[1] ... zm[n]

输出:

x1[1] ... x1[n]
...
xm[1] ... xm[n]

样例

输入1:

3
3 2
44 62 
44 43 30 
3 34 
27 63 53 
14 52 19

输出1:

-0.9846 1.5983 -0.0447 
0.7073 -0.3892 1.0744

输入2:

5
6 1
3 6 9 12
2 11 32 65 110
1 8 36 88 164 263
2 22 67 136 229
3 20 49 90
1 2 3 4 5 6

输出2:

-2.7033 -31.2044 22.0374 2.4385 -19.3077 16.0000 

解答

思路

根据 \(LU = A\) 寻找规律快速 LU 分解, 其中 \(L\) 和 \(U\) 均原地存储在 \(A\) 中; 然后利用 \(Ly = Z\)、\(Ux = y\) 解 \(X\).

测试时的用例: 矩阵计算器

参考:

代码

这段代码的三对角和五对角有着截然不同的下标处理风格. 主要是因为三对角的 mat 为紧凑存储, 其下标与矩阵中的下标有着明显的不同; 而五对角为了减少脑力消耗, 采取了非紧凑存储, 即 mat[0 ... 4] 均为 n+1 长度, 有着和矩阵中相同的下标.

#include <cstdio>
using namespace std;

#define prf printf
#define scf scanf
using dbl = double;
using arr = dbl*;

int p, n, m;

struct Matrix {
    arr* mat;
    int* lens;
    Matrix() : mat(new arr[p]), lens(new int[p]) {
        if (p == 3) {
            lens[0] = lens[2] = n - 1;
            lens[1] = n;
            for (int i = 0; i < p; i++) mat[i] = new dbl[lens[i]];
        } else if (p == 5) {
            lens[0] = lens[1] = lens[2] = lens[3] = lens[4] = n + 1;
            for (int i = 0; i < p; i++) mat[i] = new dbl[lens[i]];
        }
    };
    ~Matrix() {
        for (int i = 0; i < p; i++) delete[] mat[i];
        delete[] mat;
        delete[] lens;
    };
    void in() {
        if (p == 3) {
            for (int i = 0; i < p; i++)
                for (int j = 0; j < lens[i]; j++) scf("%lf", &mat[i][j]);
        } else if (p == 5) {
            for (int i = 1; i <= n - 2; i++) scf("%lf", &mat[0][i]);
            for (int i = 1; i <= n - 1; i++) scf("%lf", &mat[1][i]);
            for (int i = 1; i <= n; i++) scf("%lf", &mat[2][i]);
            for (int i = 2; i <= n; i++) scf("%lf", &mat[3][i]);
            for (int i = 3; i <= n; i++) scf("%lf", &mat[4][i]);
        }
    };
    void lu() {
        if (p == 3) {
            for (int i = 1; i < n; i++) {
                mat[2][i - 1] = mat[2][i - 1] / mat[1][i - 1];
                mat[1][i] = mat[1][i] - mat[2][i - 1] * mat[0][i - 1];
            }
        } else if (p == 5) {
            mat[1][1] = mat[1][1] / mat[2][1];
            mat[0][1] = mat[0][1] / mat[2][1];
            mat[2][2] = mat[2][2] - mat[3][2] * mat[1][1];
            mat[1][2] = (mat[1][2] - mat[3][2] * mat[0][1]) / mat[2][2];
            mat[0][2] = mat[0][2] / mat[2][2];
            for (int i = 3; i <= n; i++) {
                mat[3][i] = mat[3][i] - mat[4][i] * mat[1][i - 2];
                mat[2][i] = mat[2][i] - mat[4][i] * mat[0][i - 2] -
                            mat[3][i] * mat[1][i - 1];
                mat[1][i] = (mat[1][i] - mat[3][i] * mat[0][i - 1]) / mat[2][i];
                mat[0][i] = mat[0][i] / mat[2][i];
            }
        }
    };
};

struct Vector {
    dbl* vec;
    Vector() : vec(new dbl[n + 1]){};
    ~Vector() { delete[] vec; };
    void in() {
        if (p == 3) {
            for (int i = 0; i < n; i++) scf("%lf", &vec[i]);
        } else if (p == 5) {
            for (int i = 1; i <= n; i++) scf("%lf", &vec[i]);
        }
    };
    void out() {
        if (p == 3) {
            for (int i = 0; i < n; i++) prf("%.4lf ", vec[i]);
        } else if (p == 5) {
            for (int i = 1; i <= n; i++) prf("%.4lf ", vec[i]);
        }
        prf("\n");
    };
    void solve(Matrix& A) {
        if (p == 3) {
            // 解Ly = z
            for (int i = 1; i < n; i++)
                vec[i] = vec[i] - A.mat[2][i - 1] * vec[i - 1];
            // 解Ux = y
            vec[n - 1] = vec[n - 1] / A.mat[1][n - 1];
            for (int i = n - 2; i >= 0; i--)
                vec[i] = (vec[i] - A.mat[0][i] * vec[i + 1]) / A.mat[1][i];
        } else if (p == 5) {
            // 解Ly = z
            vec[1] = vec[1] / A.mat[2][1];
            vec[2] = (vec[2] - A.mat[3][2] * vec[1]) / A.mat[2][2];
            for (int i = 3; i <= n; i++)
                vec[i] = (vec[i] - A.mat[4][i] * vec[i - 2] -
                          A.mat[3][i] * vec[i - 1]) /
                         A.mat[2][i];
            // 解Ux = y
            vec[n - 1] = vec[n - 1] - A.mat[1][n - 1] * vec[n];
            for (int i = n - 2; i >= 1; i--)
                vec[i] = vec[i] - A.mat[1][i] * vec[i + 1] -
                         A.mat[0][i] * vec[i + 2];
        }
    };
};

int main() {
    scf("%d%d%d", &p, &n, &m);
    Matrix A;
    Vector z;
    A.in();
    A.lu();
    for (int i = 0; i < m; i++) {
        z.in();
        z.solve(A);
        z.out();
    }
    return 0;
}

时空消耗

1	Accepted	0 ms	828 KB
2	Accepted	0 ms	832 KB
3	Accepted	0 ms	828 KB
4	Accepted	0 ms	832 KB
5	Accepted	40 ms	864 KB
6	Accepted	84 ms	1144 KB
7	Accepted	408 ms	984 KB
8	Accepted	200 ms	864 KB
9	Accepted	164 ms	892 KB
10	Accepted	432 ms	1296 KB