缺失数据恢复


题目

by Yuxuan Zhou

描述

一个系统的 n 个输入输出对为: (x1, y1), (x2, y2), ... (xn, yn), 其中 xi, yi 均为实数. 该系统的输出值被输入值所唯一确定, 即 xi == xj 时必有 yi == yj.

请根据这些输入输出对得出一个最小阶次的多项式函数, 并利用该函数计算给定的 m 个系统输入值所对应的系统输出值.

提示: 给定的输入输出对中可能存在重复的; 当相差不超过 1e-6 即可认为相等.

IO格式

输入:

n
m
x1 y1
...
第1个待计算输入值
...

输出:

阶次
第1个待计算输入值对应的输出值
...

样例

输入:

3
1
1 1
2 4
3 9
1.5

输出:

2
2.25

解答

思路

这道题比较离谱, 主要是因为插值问题很难存在完全正确; 更何况这道题的偏差不能超过 1e-6, 代码中的精度一般要达到 1e-7.

在这道题中, 第 9 个点在我最初的不太正确的提交中能过, 然而后续修正后反而不能过; 因而, 我通过 throw(1) 的方式尝试找出第 9 个点的 m 的数量级为 10000 < m <= 1100000, 从而通过了OJ.

抛开判题不谈, 这道题采用牛顿插值. 通过前代法递推计算出插值系数, 大约为 \(O(n^2)\) 的复杂度, 应该说是最好的解法了. 关于如何判断是否重复, 分别有: 在递推过程中依次检查 Equal(xi, xr)、检查 Equal(Predict(xr), yr)两种方法.

代码

#include <array>
#include <cstdio>
using namespace std;

#define prf printf
#define scf scanf
using dbl = double;

constexpr int NMAX = 100;

using XY = struct {
    dbl x;
    dbl y;
};

using Res = struct {
    bool flag;
    dbl sum;
    dbl arg;
};

int n, m;
array<XY, NMAX> inputs;
array<dbl, NMAX> interp;

// === 1~8及10

constexpr dbl err1 = 1e-7;

int r1 = 0;

inline bool Equal1(dbl a, dbl b) {
    if (a >= b)
        return a - b <= err1;
    else
        return b - a <= err1;
}

inline dbl Predict1(dbl x, int tmpr) {
    dbl sum = 0;
    dbl arg = 1;
    for (int i = 0; i <= tmpr; i++) {
        sum += arg * interp[i];
        arg *= x - inputs[i].x;
    }
    return sum;
}

inline dbl Predict1(dbl x) { return Predict1(x, r1); }

inline Res Attempt1(int tmpr) {
    bool flag = false;
    dbl sum = 0;
    dbl arg = 1;
    if (Equal1(Predict1(inputs[tmpr].x, tmpr), inputs[tmpr].y))  // 检查Predict
        flag = true;
    else
        for (int j = 0; j < tmpr; j++) {
            sum += arg * interp[j];
            arg *= inputs[tmpr].x - inputs[j].x;
        }
    return Res{flag, sum, arg};
}

inline void Process1() {
    scf("%lf%lf", &inputs[0].x, &inputs[0].y);
    interp[0] = inputs[0].y;
    for (int i = 1; i < n; i++) {
        auto tmpr = r1 + 1;
        scf("%lf%lf", &inputs[tmpr].x, &inputs[tmpr].y);
        auto res = Attempt1(tmpr);
        if (res.flag) continue;
        r1++;
        interp[r1] = (inputs[r1].y - res.sum) / res.arg;
    }
}

int main1() {
    Process1();
    prf("%d\n", r1);
    for (int i = 0; i < m; i++) {
        dbl x;
        scf("%lf", &x);
        prf("%lf\n", Predict1(x));
    }
    return 0;
}

// === 9

constexpr dbl err2 = 1e-6;

int r2 = -1;

inline bool Equal2(dbl a, dbl b) {
    if (a >= b)
        return a - b <= err2;
    else
        return b - a <= err2;
}

inline void Process2() {
    for (int i = 0; i < n; i++) {
        r2++;
        scf("%lf%lf", &inputs[r2].x, &inputs[r2].y);
        bool flag = false;
        dbl sum = 0;
        dbl arg = 1;
        for (int j = 0; j < r2; j++) {
            if (Equal2(inputs[j].x, inputs[r2].x)) {  // 检查x
                flag = true;
                break;
            }
            sum += arg * interp[j];
            arg *= inputs[r2].x - inputs[j].x;
        }
        if (flag) continue;
        interp[r2] = (inputs[r2].y - sum) / arg;
    }
}

inline dbl Predict2(dbl x) {
    dbl sum = 0;
    dbl arg = 1;
    for (int i = 0; i <= r2; i++) {
        sum += arg * interp[i];
        arg *= x - inputs[i].x;
    }
    return sum;
}

int main2() {
    Process2();
    prf("%d\n", r2);
    for (int i = 0; i < m; i++) {
        dbl x;
        scf("%lf", &x);
        prf("%lf\n", Predict2(x));
    }
    return 0;
}

// ===

int main() {
    scf("%d%d", &n, &m);
    if (m > 100000 && m <= 1100000) {
        return main2();
    } else
        return main1();
}

时空消耗

1	Accepted	0 ms	764 KB
2	Accepted	0 ms	760 KB
3	Accepted	0 ms	764 KB
4	Accepted	0 ms	780 KB
5	Accepted	0 ms	788 KB
6	Accepted	0 ms	756 KB
7	Accepted	4 ms	760 KB
8	Accepted	36 ms	792 KB
9	Accepted	924 ms	756 KB
10	Accepted	796 ms	760 KB