张量相乘的最小开销问题
题目
by 吴世光
描述
规定一种张量乘法, 对于 \(K(K \geq 2)\) 维张量, 前 \(K - 2\) 维按张量广播乘法, 后两维按矩阵乘法. 现有 \(n\) 个 \(K\) 维张量按此相乘, 可以为其加上若干对括号, 求最少标量乘法次数.
提示: 张量个数小于 \(2048\), 维数小于 \(32\), 每个维度的大小小于 \(1000\).
IO格式
输入:
n K
第1个张量第1维长度...
...
输出:
最少标量乘法次数
样例
输入:
3 3
1 1 2
1 2 3
10 3 4
输出:
126
解答
思路
这是一道典型的区间 DP.
对于动态规划问题, 我们的解题思路是: 设计状态函数、求状态函数初值、写出转移方程、化简转移方程、遍历求解.
状态函数
设 dp(i, j)
为从第 i
个张量一直连乘到第 j
个张量的最少标量乘法次数. 我们不难得出以下两个结论:
- 我们的所求为
dp(0, n-1)
; - 我们已经得知的初值为
dp(i, i) = 0 for i = 0 ... n-1
.
转移方程
对于一个添加括号的行为, 其结果是将一部分的乘法运算独立, 因此其最简单的就是将一连串的乘法分为两段 (构造最优子结构); 因此, 转移方程可以写为如下:
dp(i, j) = min(
dp(i, k) + dp(k+1, j) + cost(i, j, k)
for k = i ... j-1
)
这里, cost(i, j, k)
是将两部分的乘法结果合并所产生的花销, 关于它究竟是什么暂且按下不谈.
求解
现在我们来考虑一下 i
和 j
的取值.
考虑如下的一张 dp 表, 我们已经填入了初值; 表中最右上角即为我们的目标解:
\[\begin{bmatrix} 0 & \infty & \infty & \cdots & \infty \\ & 0 & \infty & \ddots & \vdots \\ & & 0 & \ddots & \vdots \\ & & & \ddots & \infty \\ & & & & 0 \end{bmatrix}\]根据状态方程, 对于一个待求的 dp(i, j)
, 我们需要知道在第 i
行和第 j
列上介于 dp(i, j)
和 dp(i, i)
、dp(j, j)
的值:
因此, 我们将会遵循以下的遍历顺序:
- 以
s = 0
为对角线,s
可以赋予意义为 dp 表的斜行坐标, 不过它在问题中的真实意义为我们对问题拆分的小问题的步长. for s = 1 ... n-1
, 我们有j = i + s for i = 0 ... n-1-s
.
这个遍历为 \(O(n^2)\) 的遍历.
合并
最后, 我们来看一下 cost(i, j, k)
. 对于两段已经乘好的张量, 我们将其合并, 我们可以得出如下两部分花销:
- 一是前面 \(K - 2\) 维的结果恒定, 并且张量乘法遵循广播, 所以由于维度导致的标量乘法次数系数是不变的:
cost1(i, j) = prod(max(Tensor[i ... j]D(k)) for k = 1 ... K-2)
. - 二是在 \(l \times m\) 和 \(m \times n\) 矩阵乘的过程中, 这一部分的花销形如 \(lmn\). \(m\) 是变化的, 但是 \(l\) 和 \(n\) 是不变的:
cost2(i, j) = Tensor[i]D(K-1) * Tensor[j]D(K)
,cost3(k) = Tensor[k]D(K)
.
故, cost(i, j, k) = cost1(i, j) * cost2(i, j) * cost3(k)
.
代码
初步思路
最初为了计算 cost(i, j, k)
, 我是希望将计算过程全部记录下来的, 但是这不仅费空间, 而且不是一个好的算法:
#include <cstdio>
#include <climits>
using namespace std;
#define prf printf
#define scf scanf
using ui = unsigned int;
using us = unsigned short;
ui n, K, Ks;
struct Tensor {
Tensor() : dims(new us[Ks]), mat(new us[2]) {}
~Tensor() {
delete[] dims;
delete[] mat;
}
us *dims, *mat;
void in() {
for (ui i = 0; i < Ks; i++) scf("%hu", &dims[i]);
scf("%hu%hu", &mat[0], &mat[1]);
}
};
struct Cost {
Cost() : val(UINT_MAX) {}
Cost(ui _val) : val(_val) {}
ui val;
operator ui() { return val; }
bool operator<(Cost& other) { return val < other.val; }
};
struct Dp {
Dp(ui _n) : n(_n), __arr(new Cost[_n * _n]), __res(new Tensor[_n * _n]) {}
~Dp() {
delete[] __arr;
delete[] __res;
}
ui n;
Cost* __arr;
Tensor* __res;
Cost& operator()(ui i, ui j) { return __arr[i * n + j]; }
Tensor& result(ui i, ui j) { return __res[i * n + j]; }
};
inline ui cnt(Tensor& a, Tensor& b, Tensor& res) {
ui ret = 1;
for (ui i = 0; i < Ks; i++)
ret *= res.dims[i] = a.dims[i] == 1 ? b.dims[i] : a.dims[i];
ret *= (res.mat[0] = a.mat[0]) * b.mat[0] * (res.mat[1] = b.mat[1]);
return ret;
}
int main() {
scf("%u%u", &n, &K);
Ks = K - 2;
Dp dp(n);
for (ui i = 0; i < n; i++) {
dp.result(i, i).in();
dp(i, i) = 0;
}
ui rend = n - 1;
for (ui r = 1; r <= rend; r++) {
ui iend = n - 1 - r;
for (ui i = 0; i <= iend; i++) {
ui j = i + r;
Cost& cur = dp(i, j);
Tensor& curRes = dp.result(i, j);
for (ui k = i; k < j; k++) {
Cost cost = dp(i, k) + dp(k + 1, j) +
cnt(dp.result(i, k), dp.result(k + 1, j), curRes);
if (cost < dp(i, j)) cur = cost;
}
}
}
prf("%u", ui(dp(0, n - 1)));
return 0;
}
最终提交
#include <cstdio>
#include <climits>
using namespace std;
#define prf printf
#define scf scanf
using ui = unsigned int;
using us = unsigned short;
ui n, K, Ks;
struct Tensor {
Tensor() : dims(new us[Ks]), mat(new us[2]) {}
~Tensor() {
delete[] dims;
delete[] mat;
}
us *dims, *mat;
void in() {
for (ui i = 0; i < Ks; i++) scf("%hu", &dims[i]);
scf("%hu%hu", &mat[0], &mat[1]);
}
};
struct Cost {
Cost() : val(UINT_MAX) {}
Cost(ui _val) : val(_val) {}
ui val;
operator ui() { return val; }
bool operator<(Cost& other) { return val < other.val; }
};
struct Dp {
Dp(ui _n) : n(_n), __arr(new Cost[_n * _n]), __ten(new Tensor[_n]) {}
~Dp() {
delete[] __arr;
delete[] __ten;
}
ui n;
Cost* __arr;
Tensor* __ten;
Cost& operator()(ui i, ui j) { return __arr[i * n + j]; }
Tensor& result(ui i) { return __ten[i]; }
};
inline ui cnt(Dp& dp, ui i, ui j) {
ui ret = 1;
for (ui kk = 0; kk < Ks; kk++) {
for (ui ii = i; ii <= j; ii++) {
ui dim = dp.result(ii).dims[kk];
if (dim != 1) {
ret *= dim;
break;
}
}
}
return ret;
}
inline ui cnt(Dp& dp, ui i, ui j, ui k) {
return dp.result(i).mat[0] * dp.result(k).mat[1] * dp.result(j).mat[1];
}
int main() {
scf("%u%u", &n, &K);
Ks = K - 2;
Dp dp(n);
for (ui i = 0; i < n; i++) {
dp.result(i).in();
dp(i, i) = 0;
}
ui rend = n - 1;
for (ui r = 1; r <= rend; r++) {
ui iend = n - 1 - r;
for (ui i = 0; i <= iend; i++) {
ui j = i + r;
Cost& cur = dp(i, j);
ui curCntKs = cnt(dp, i, j);
for (ui k = i; k < j; k++) {
Cost cost =
dp(i, k) + dp(k + 1, j) + curCntKs * cnt(dp, i, j, k);
if (cost < cur) cur = cost;
}
}
}
prf("%u", ui(dp(0, n - 1)));
return 0;
}
时空消耗
1 Accepted 0 ms 832 KB
2 Accepted 0 ms 832 KB
3 Accepted 60 ms 1716 KB
4 Accepted 0 ms 828 KB
5 Accepted 60 ms 1716 KB
6 Accepted 0 ms 836 KB
7 Accepted 64 ms 1712 KB
8 Accepted 164 ms 1712 KB
9 Accepted 2832 ms 13564 KB
10 Accepted 2844 ms 13564 KB