Page List

Search on the blog

2011年5月27日金曜日

ビット演算魔術、部分集合を高速に出すの巻

 以下のビット演算が何を意味するのか知ってる人は読まなくていいです。
--y &= x

 次の問題を解いているときに出会いました。
[SRM474 SquaresCovering]
2次元平面上に分散するN個の点を正方形で多いたい。正方形はK種類あり、それぞれ大きさと使用するためのコストが異なる。すべての点を覆うには最低いくらコストが必要か?ただし、正方形は無限個あり、同じ正方形を複数回使用しても構わない。

N<=16, K<=50
0 < xi < 10^9, i=1,..,n
0 < yi < 10^9, i=1,..,n

 ぱっと見普通のビットDPで解けそう。が、2秒縛りがキツイ。straightforwardに解くとO(2^{2N})なので、なんとかしたい。
何とかできなかった。。
以下rngさんの回答。やっぱ数学オリンピック金メダリストは違う。。と思ったけど、アリーナからコピペ出来なかったので、私のヘタレソースで我慢してください。(横にはみ出していて、見にくい場合はview plainで見てください。)
  1. const long long int INF = 999999999999999LL;  
  2. long long int dp[1<<16];  
  3.   
  4. class SquaresCovering {  
  5. public:  
  6.   int minCost(vector<int> x, vector<int> y, vector<int> cost, vector<int> sides) {  
  7.       int N = x.size();  
  8.       int K = sides.size();  
  9.   
  10.       REP(mask, 1<<N) {  
  11.           LL xmin = INF, xmax = -1;  
  12.           LL ymin = INF, ymax = -1;  
  13.   
  14.           REP(i, N) {  
  15.               if (mask >> i & 1) {  
  16.                   xmin = MIN(x[i], xmin);  
  17.                   xmax = MAX(x[i], xmax);  
  18.                   ymin = MIN(y[i], ymin);  
  19.                   ymax = MAX(y[i], ymax);  
  20.               }  
  21.           }  
  22.   
  23.           LL val = INF;  
  24.           int diff = max(xmax-xmin, ymax-ymin);  
  25.           REP(i, K)  
  26.               if (diff <= sides[i])  
  27.                   val = MIN(val, cost[i]);  
  28.           dp[mask] = val;  
  29.       }  
  30.   
  31.       REP(x, 1<<N) {  
  32.           for (int y = x; y > 0; --y &= x)  
  33.               dp[x] = min(dp[x], dp[y]+dp[x^y]);  
  34.       }  
  35.   
  36.       return (int)dp[(1<<N)-1];  
  37.   }  
  38. };  
 --y & xは、集合xの部分集合をすべて列挙することができます。これはかなりカッコいいです。ピンと来ない人は次のソースを実行してみましょう。
  1. #include <bitset>  
  2. int main() {  
  3.   int x = 202;  
  4.   
  5.   for (int y = x; y > 0; --y &= x)  
  6.       cout << bitset<8>(y) << endl;  
  7. }  
上の実行結果を見れば何がしたいか分かるはずです。

 最後に、比較してみましょう。SRMの問題を愚直な2重ループでやると、4*10^9くらいかかります。
この方法を使用すれば、2^16 * 2^8 = 2^24 ~16*10^6くらいには収まるかなという予想です。
  1. int main() {  
  2.   LL x = 0;  
  3.   REP(mask1, 1<<16)  
  4.       REP(mask2, 1<<16)  
  5.           x++;  
  6.   
  7.   LL y = 0;  
  8.   REP(mask1, 1<<16)  
  9.       for (int mask2 = mask1; mask2 > 0; --mask2 &= mask1)  
  10.           y++;  
  11.   
  12.   cout << x << endl;  
  13.   cout << y << endl;  
  14.   
  15.   return 0;  
  16. }  
上のプログラムを実行した場合、出力される数値は、
4294967296
42981185
です。まあこんなもんでしょうか。100倍くらい速くなります。



0 件のコメント:

コメントを投稿