Bitmask DP & คำอธิบายวิธีทำพร้อม code สำหรับข้อ toi20_bit_string

สรุปโจทย์

เรามี Binary String (สตริงที่ประกอบด้วย ‘1’ และ ‘0’ เท่านั้น) ความยาว NN ซึ่งหากเราพิจารณา String ความยาว NN ทุกรูปแบบ จะพบว่ามี 2N2^N รูปแบบ โดยแต่ละรูปแบบ จะมีค่าน้ำหนัก CiC_i

การลดทอนคุณภาพ จะเป็นการเปลี่ยน ‘1’ เป็น ‘0’ ซึ่งสามารถทำได้ 2 วิธี

  1. เปลี่ยน ‘1’ เป็น ‘0’
  2. เปลี่ยน ‘11’ (อักขระ ‘1’ สองตัวติดกัน) เป็น ‘00’ (อักขระ ‘0’ สองตัวติดกัน)

โดยเมื่อได้ Binary String ใหม่ที่เกิดจากการเปลี่ยน ‘1’ เป็น ‘0’ เราจะบวกค่าน้ำหนักของ Binary String ในขณะนั้น ไปกับ "ค่าลดทอนคุณภาพ"
แล้ว Operation นี้ จะจบลงเมื่อ String กลายเป็น “00000…0” ซึ่งเราจะต้องการให้ ค่าลดทอนคุณภาพ ค่ามากที่สุด

ตัวอย่างการลดทอนคุณภาพ

N=3N = 3
001=9001 = 9
000=0000 = 0
100=9100 = 9
010=1010 = 1
101=1101 = 1
110=2110 = 2
111=3111 = 3
011=1011 = 1
Binary String ที่กำหนดให้: 111111

ขั้นตอน:
ค่าลดทอนคุณภาพ เริ่มต้นที่ C111C_{111} = 3

  1. 111 -> 011 : ansans เพิ่มไป C011C_{011} = 1 -> 4
  2. 011 -> 001 : ansans เพิ่มไป C001C_{001} = 9 -> 13

ดังนั้น คำตอบของ Binary String 111111 คือ 13

โดยโจทย์จะกำหนด Binary String มาให้ QQ อัน และให้ตอบให้ครบ

สิ่งที่ต้องทำ

เขียนโค้ดเพื่อหา ค่าลดทอนคุณภาพ ที่มีค่ามากที่สุด สำหรับ Binary String ทั้ง QQ อัน

ขอบเขตข้อมูล

2N202 \leq N \leq 20
500,000Ci1,000,000-500,000 \leq C_i \leq 1,000,000

Prerequisites

Bitmask DP

สำหรับข้อนี้ เราจะใช้สิ่งที่เรียกว่า Bitmask DP ซึ่งเป็นรูปแบบการทำ Dynamic Programming ที่จะค่อนข้างพิสดารเล็กน้อย

โดย Bitmask DP นั้น ลักษณะจะเป็นการเล่นกับ เลขฐานสอง ซึ่งแทนที่จะเก็บ DP ในแต่ละ Index เราจะเก็บ DP ในแต่ละ Mask (เลขฐานสอง) แทน นั่นคือ เราจะมีทั้งหมด 2N2^N Mask (เราต้องการเล่นกับทุกรูปแบบของเลขฐานสองที่มี NN หลัก)

โดยโจทย์ส่วนใหญ่นั้น เราจะเก็บ State DP ในลักษณะของ

ลักษณะการทำ Bitmask DP ทั่วไปคือ

  1. loop ตามทุกๆ maskmask ที่เป็นไปได้ O(2N)O(2^N)
  2. ลูปตามแต่ละตัวใน maskmask นั้น แล้วเช็กว่า หากสับ bit นั้นออก (1>0)(1 > 0) แล้ว dpdp ของ state นั้น เคยทำมาแล้วหรือไม่ แล้วทำได้มั้ย ดังรูปด้านล่าง (เปลี่ยนทุก bit ที่เป็น 1 ให้เป็น 0 แล้วเช็ก)
  3. หากทำได้ ก็เอาเข้า cost function เพื่อนำมาใส่ใน dp[mask]dp[mask] ปัจจุบัน

โดยในโจทย์ข้อนี้ เราจะกำหนดลักษณะ DP เป็นรูปแบบที่ 2 นั่นคือ dp[mask]=dp[mask] = ค่าที่ต่ำที่สุดที่จะนำเรามายัง state ที่ maskmask

วิธีทำ

ก่อนอื่น เราก็จะต้องรับ input ว่า สำหรับแต่ละ maskmask ค่าลดทอนคุณภาพของมันจะเป็นเท่าไหร่ เราก็จะรับ input มาเป็น Binary String แล้วก็เข้า function binarybinary ที่เขียนไว้ เพื่อแปลง Binary String เป็น เลขฐานสิบ

หลักจากนั้น เราก็แค่นำไอเดียของ Bitmask DP ด้านบน ลงมาใช้ โดยเราจะเก็บว่า คำตอบที่มากที่สุดสำหรับ dp[mask]dp[mask] เป็นเท่าไหร่ โดยเราจะ Loop ไปทุกๆ maskmask (นั่นคือ วน 2N2^N รอบ) โดยสำหรับแต่ละ maskmask จะคำนวณดังนี้:

Loop ครั้งที่ 1 (สับออก 1 bit)

  1. เช็กแต่ละ index ว่า maskmask ปัจจุบันในตำแหน่งที่ ii เป็น 1 หรือ 0
    1.1 หากเป็น 0 ก็ข้าม
    1.2 หากเป็น 1 ให้เอา bit นั้นเป็น 0 (ตั้งชื่อว่า nmnm ย่อจาก newmasknew mask) แล้วก็
    เก็บ dp[mask]=max(dp[mask],dp[nm]+a[mask])dp[mask] = max(dp[mask], dp[nm] + a[mask]) ซึ่งคือ cost function ของเรา

Loop ครั้งที่ 2 (สับออก 2 bit)

  1. เช็กแต่ละ index ว่า maskmask ปัจจุบันในตำแหน่งที่ ii และ i+1i+1 เป็น 1 หรือ 0
    1.1 หากสักอันเป็น 0 ก็ข้าม
    1.2 หากเป็น 1 ทั้งคู่ ให้เอาทั้งสอง bit นั้นเป็น 0 (ตั้งชื่อว่า nmnm ย่อจาก newmasknew mask) แล้วก็เก็บ dp[mask]=max(dp[mask],dp[nm]+a[mask])dp[mask] = max(dp[mask], dp[nm] + a[mask]) ซึ่งคือ cost function ของเรา

แล้วเมื่อรับคำถามมา เราก็แค่ส่ง Output ไปเป็น dp[input]dp[input] สำหรับแต่ละอินพุตได้เลย

Code

#include <bits/stdc++.h>

using namespace std;

const long long mod = 1e9 + 7;
const long long inf = 1e18;

long long binary(string s){
    long long ans = 0;
    reverse(all(s));
    for (long long i = 0; i < s.length(); i++) {
        ans += (s[i] - '0') * (1 << i);
    }
    return ans;
}

int32_t main(){
    cin.tie(NULL)->sync_with_stdio(false);
    long long n, q; cin >> n >> q;
    vector <long long> a(1 << n);
    for (long long i = 0; i < (1 << n); i++) {
        string s; cin >> s;
        long long num; cin >> num;
        a[binary(s)] = num;
    }
    vector <long long> dp(1 << n, -inf); dp[0] = 0;
    for (long long mask = 1; mask < (1 << n); mask++) {
        // swap 1 bit
        for (long long i = 0; i < n; i++) {
            if ((mask & (1 << i)) == 0) continue;
            long long nm = mask ^ (1 << i);
            dp[mask] = max(dp[mask], dp[nm] + a[mask]);
        }
        // swap 2 bits
        for (long long i = 0; i < n - 1; i++) {
            if ((mask & (1 << i)) == 0 || (mask & (1 << (i + 1))) == 0) continue;
            long long nm = mask ^ (1 << i) ^ (1 << (i + 1));
            dp[mask] = max(dp[mask], dp[nm] + a[mask]);
        }
    }
    while (q--) {
        string s; cin >> s;
        cout << dp[binary(s)] << "\n";
    }
}

Total Time Complexity: O(2NN)O(2^N \cdot N)

หากมีข้อสงสัย comment ไว้ใต้ post ได้เลยนะครับ 🙇‍♂️🙇‍♂️
ศึกษาโจทย์เพิ่มเติมได้ที่ Fast X Fourier