Fast primality checking, factorizing and totient calculation in Competitive Programming*

*This post only covers n \leq 10^{18}. That is usually enough, since problems rarely require BigInteger just to store the input.

I am going to explain how to solve phi-phi-phi, which at the same time covers how to do primality checking and totient calculation. The problem asks us to apply the totient function repeatedly, starting from a given value. The value of k itself can be ignored, because the value reaches \phi(1) = 1 after only O(\log n) applications. For n > 2, \phi(n) is even, and for even m, \phi(m) \leq m/2, so the value at least halves every two steps.

First, we sieve all primes up to 10^{6}. Sieving is common enough that I will not explain it here; see Sieve Methods: Prime, Divisor, Euler Phi etc. for a reference. After sieving, we calculate \phi k times, or until the value reaches 1.

For a value n, we factorize n by trial division with the primes from the sieve. There are 78498 primes in the range 1 \leq p \leq 10^{6}. After dividing out all of those primes, let the leftover value be r. If r \neq 1, then r has no prime factor \leq 10^6. Since (10^6)^3 = 10^{18} \geq n, r has at most two prime factors, which leaves 3 cases:

  1. r = p^2, where p is a prime, 10^6 < p \leq 10^9.
  2. r = p_1 \cdot p_2, where p_1 and p_2 are primes, 10^6 < p_1, p_2 \leq 10^{12} and p_1 \neq p_2.
  3. r = p, where p is a prime, 10^6 < p \leq 10^{18}.

The first case is easy to check: take \lfloor \sqrt r \rfloor and check whether its square equals r. With binary search this takes O(\log r).

To check for the third case, we can use the Miller–Rabin primality test. The deterministic variant uses the first 12 primes (2 to 37) as witnesses, which is enough for every n < 2^{64}. It tells us whether r is prime in O(k \log^3 r), where k = 12 is the number of witnesses (not the k from the input).

The second case requires finding p_1 and p_2, for which we can use Pollard's rho algorithm. The only difficulty is evaluating g(x) = x^2 + 1 \bmod n. Since x < n \leq 10^{18}, a plain (x * x) % n is not good enough: x * x can exceed the 64-bit integer range and overflow. To avoid the overflow we use Russian (peasant) multiplication, i.e. double-and-add. It computes a \cdot b \bmod n in O(\log b) modular additions, and no intermediate value ever exceeds 2n. The same idea gives modular exponentiation with O(\log b) such multiplications and O(1) extra space. Overall, Pollard's rho needs an expected (heuristic) O(r^{1/4}) iterations.

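The totient can be computed in two ways. The first is the formula \phi(n) = \prod_{i=1}^{k} (p_i^{\alpha_i} - p_i^{\alpha_i - 1}), where n = \prod_{i=1}^{k} p_i^{\alpha_i}. The second is the inclusion–exclusion principle, \phi(n) = \sum_{S \subseteq \{p_1, \dots, p_k\}} (-1)^{|S|} \frac{n}{\prod_{p \in S} p}. I use the second one.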

The following code is my submission:

#include <cstdio>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <string>
#include <stack>
#include <queue>
#include <deque>
#include <vector>
#include <map>
#include <set>
#include <utility>
#include <algorithm>
#include <cmath>
#include <climits>
#ifdef DEBUG
    #include <ctime>
#endif
using namespace std;

// template

// abbreviations

typedef unsigned long long ull;
typedef long long ll;
typedef vector<int> vi;
typedef vector<vi> vvi;
typedef vector<ll> vl;
typedef vector<vl> vvl;
typedef vector<string> vs;
typedef pair<int, int> ii;
typedef vector<ii> vii;
typedef map<int, int> mii;
// NOTE: `a`, `b` and `it` are object-like macros, so EVERY identifier spelled
// `a`, `b` or `it` after this point is rewritten (a parameter named `a`
// silently becomes `first`). Do not use these names for anything else.
#define a first
#define b second
#define que queue
#define pque priority_queue
#define stk stack
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define pu push
#define po pop
#define mp make_pair
#define it iterator
// Macro arguments are parenthesized so that e.g. sz(*p) or rep(i, n + 1)
// expand with the intended precedence. One #define per line: a second
// `#define` inside a macro body is not a directive (and `#` followed by a
// non-parameter is ill-formed in a function-like macro).
#define sz(var) ((int) (var).size())
#define rep(it, n) for(int it = 0; it < (n); ++it)
#define dep(it, n) for(int it = (n) - 1; it >= 0; --it)
#define rep1(it, n) for(int it = 1; it <= (n); ++it)
#define dep1(it, n) for(int it = (n); it > 0; --it)
#define loop(it, from, to) for(int it = (from); it <= (to); ++it)
#define iter(it, cont) for(__typeof((cont).begin()) it = (cont).begin(); it != (cont).end(); ++it)
#define riter(it, cont) for(__typeof((cont).rbegin()) it = (cont).rbegin(); it != (cont).rend(); ++it)
#define all(cont) (cont).begin(), (cont).end()
#define rng(cont, n) (cont), (cont) + (n)
#define memclr(var) memset((var), 0, sizeof(var))

const int INF = INT_MAX;
const int NINF = INT_MIN;
const ll INF_LL = LLONG_MAX;
const ll NINF_LL = LLONG_MIN;
const double PI = acos(-1.0);
const int MOD = 1e9 + 7;

// Modular arithmetic helpers; every function takes the modulus c (default MOD).
// Sums of two residues are formed directly, so c must satisfy 2 * c <= LLONG_MAX
// (about 4.6e18) to stay overflow-free.

// Adds c to a until it is non-negative. Callers pass a value already reduced
// with % c, so the loop runs at most once.
inline ll pos_m(ll a, ll c = MOD) {
    while (a < 0) {
        a += c;
    }
    return a;
}

// (a + b) mod c; a + b itself must not overflow.
inline ll add_m(ll a, ll b, ll c = MOD) { return (a + b) % c; }

// Russian (double-and-add) multiplication: a * b mod c in O(log b) additions,
// never forming the full product, so it works for c up to ~1e18 where a * b
// would overflow 64 bits. Returns a value in [0, c). Requires b >= 0.
inline ll mul_m(ll a, ll b, ll c = MOD) {
    ll ret = 0;
    a = pos_m(a % c, c);  // reduce first so the doublings below stay below 2c
    while (b) {
        if (b & 1)
            ret = add_m(ret, a, c);
        a = add_m(a, a, c);
        b >>= 1;
    }
    return ret;
}

// (a - b) mod c, normalized into [0, c).
inline ll sub_m(ll a, ll b, ll c = MOD) { return pos_m((a - b) % c, c); }

// a^b mod c by square-and-multiply on top of mul_m: O(log b) multiplications,
// O(1) extra space. Requires b >= 0.
inline ll pow_mod(ll a, ll b, ll c = MOD) {
    ll ret = 1 % c;  // so that x^0 mod 1 == 0
    while (b) {
        if (b & 1)
            ret = mul_m(a, ret, c);
        a = mul_m(a, a, c);
        b >>= 1;
    }
    return ret;
}

#ifdef DEBUG
    #define debug(fmt, args...) printf("Line %d, in %s\t: " fmt, __LINE__, __FUNCTION__, ##args)
    #define rep_rt() printf("[Run time: %.3fs]\n", ((double) clock()) / CLOCKS_PER_SEC)
#else
    #define debug(...)
#endif

// end of template

// Upper bound of the sieve. Trial division by every prime up to 10^6 leaves
// at most two (large) prime factors of a value <= 10^18.
#define MAXSQRT3A ((int) 1000000)
int is_prime[MAXSQRT3A + 1];  // sieve flags, filled in main: nonzero iff index is prime
vi primes;                    // primes <= MAXSQRT3A in increasing order, filled in main

// Deterministic Miller–Rabin primality test: returns true iff n is prime.
// The first 12 primes (2..37) as witnesses are sufficient for every n < 2^64.
// Modular products go through mul_m (Russian multiplication), so no 64-bit
// overflow occurs as long as 2 * n fits in a signed 64-bit integer.
// The witnesses are local, so this works before the global sieve is built.
bool is_prime_fun(ll n) {
    static const int witnesses[12] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};
    if (n < 2)
        return false;  // also avoids the endless halving loop below for n == 1
    for (int i = 0; i < 12; ++i) {
        if (n == witnesses[i])
            return true;
        if (n % witnesses[i] == 0)
            return false;  // a witness dividing n would make pow_mod return 0
    }

    // Write n - 1 = d * 2^s with d odd.
    int s = 0;
    ll d = n - 1;
    while (d % 2 == 0) {
        ++s;
        d >>= 1;
    }

    for (int i = 0; i < 12; ++i) {
        ll x = pow_mod(witnesses[i], d, n);
        if (x == 1 || x == n - 1)
            continue;
        // Square repeatedly: x^(d*2^r) for r = 1 .. s-1 must hit n - 1.
        bool composite = true;
        for (int r = 1; r < s; ++r) {
            x = mul_m(x, x, n);
            if (x == n - 1) {
                composite = false;
                break;
            }
        }
        if (composite)
            return false;
    }
    return true;
}

// Pollard's rho iteration polynomial: g(x) = x^2 + 1 (mod n).
// The square is taken with the overflow-safe Russian multiplication helper,
// because x * x can exceed 64 bits when n is close to 10^18.
ll pr_g(ll x, ll n) {
    const ll squared = mul_m(x, x, n);
    return add_m(squared, 1, n);
}

// Pollard's rho with Floyd cycle detection (tortoise moves one step of pr_g,
// hare moves two). Each of the first 12 sieved primes is tried as a starting
// value. Returns a non-trivial divisor of val, or -1 when every start ends with
// gcd == val (the cycle closed without exposing a factor).
ll pr_get_prime_fact(ll val) {
    int attempt = 0;
    while (attempt < 12) {
        const ll seed = primes[attempt++];
        ll tortoise = seed;
        ll hare = seed;
        ll g = 1;
        do {
            tortoise = pr_g(tortoise, val);
            hare = pr_g(pr_g(hare, val), val);
            g = __gcd(abs(tortoise - hare), val);
        } while (g == 1);
        if (g != val)
            return g;
    }
    return -1;
}

// Euler's totient of lim via inclusion–exclusion:
//   phi(lim) = sum over subsets S of divs of (-1)^|S| * lim / prod(S).
// divs must hold the distinct prime divisors of lim (each prime once). Every
// subset product then divides lim, so it never exceeds lim and cannot overflow.
// Runs in O(2^k * k) for k = divs.size(). A value <= 10^18 has at most 15
// distinct primes, so the 1 << k shift is safe.
// Takes divs by const reference so the vector is not copied on every call.
long long tot_func(const std::vector<long long>& divs, long long lim) {
    const int k = static_cast<int>(divs.size());
    long long ret = 0;
    for (int mask = 0; mask < (1 << k); ++mask) {
        long long prod = 1;
        int chosen = 0;
        for (int i = 0; i < k; ++i) {
            if (mask & (1 << i)) {
                prod *= divs[i];
                ++chosen;
            }
        }
        const long long multiples = lim / prod;  // numbers in [1, lim] divisible by prod
        ret += (chosen % 2 == 0) ? multiples : -multiples;
    }
    return ret;
}

// Integer square root: returns floor(sqrt(val)) for any val >= 0 (0 for val < 0),
// found by binary search. The upper bound 3037000499 is floor(sqrt(LLONG_MAX)),
// so mid * mid never overflows and the whole long long range is covered.
// The result r always satisfies r * r <= val, so a caller's perfect-square check
// `r * r == val` cannot overflow either.
long long sqrt_ll(long long val) {
    long long lo = 0;
    long long hi = 3037000499LL;
    while (lo < hi) {
        const long long mid = lo + (hi - lo + 1) / 2;  // round up so lo always advances
        if (mid * mid <= val) {
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }
    return lo;
}

// phi-phi-phi driver: reads n and k, replaces n by Euler's totient phi(n) up to
// k times (stopping early once the value becomes 1) and prints the final value.
int main() {
#ifdef DEBUG
    freopen("phi-phi-phi.in", "r", stdin);
#endif

    // Sieve of Eratosthenes over [0, MAXSQRT3A].
    rep1(val, MAXSQRT3A) {
        is_prime[val] = true;
    }
    is_prime[0] = is_prime[1] = false;
    for (int i = 2; i * i <= MAXSQRT3A; ++i) {
        if (is_prime[i]) {
            // Cross out multiples i * j with j >= i; smaller multiples were
            // already crossed out by a smaller prime factor.
            for (int j = i; i * j <= MAXSQRT3A; ++j)
                is_prime[i * j] = false;
        }
    }
    // Collect the sieved primes in increasing order (used for trial division below).
    rep1(val, MAXSQRT3A) {
        if (is_prime[val])
            primes.pub(val);
    }
    debug("%d\n", sz(primes));

    ll n, k;
    scanf("%lld %lld", &n, &k);
    // Each round computes phi of the current value. The loop exits as soon as
    // the value reaches 1, so a very large k does not cost extra rounds.
    while (k--) {
        ll lim = n; // value whose totient this round computes; n is consumed below

        vl divs; // distinct prime divisors of lim
        // Trial division by every sieved prime: strips all prime factors
        // <= MAXSQRT3A from n and records each distinct one. The scan always
        // visits every prime, even after n has dropped to 1.
        // NOTE(review): assumes n >= 1 -- for n == 0 the inner while never ends.
        iter(pp, primes) {
            int p = *pp;
            if (n % p == 0) {
                while (n % p == 0)
                    n /= p;
                divs.pub(p);
                debug("%d\n", p);
            }
        }
        if (n != 1) { // means n is a prime, sq prime or p1 * p2
            // For inputs up to 10^18, the cofactor has no prime factor
            // <= MAXSQRT3A, so it is p^2, p, or p1 * p2 (p1 != p2).
            ll sqp = sqrt_ll(n);
            if (sqp * sqp == n)
                divs.pub(sqp);
            else if (is_prime_fun(n)) {
                divs.pub(n);
            } else {
                // NOTE(review): pr_get_prime_fact returns -1 when every Pollard
                // rho start fails; that case is not handled here and would push
                // bogus (negative) divisors.
                ll p1 = pr_get_prime_fact(n);
                ll p2 = n / p1;
                divs.pub(p1);
                divs.pub(p2);
            }
        }
        n = tot_func(divs, lim);
        if (n == 1)
            break;
    }
    printf("%lld\n", n);

#ifdef DEBUG
    rep_rt();
#endif
    return 0;
}

SPOJ: WILD – Wild West

This post explains my solution for WILD – Wild West. With fast IO, my solution reached first rank (at the time of posting). Ahmad Zaky (user: azaky) helped me with this solution.

If we analyze the problem carefully, it actually asks for the volume of the cube with side length m minus the volume of the union of the cuboids. By "union of cuboids" I mean the shape formed when the cuboids are merged together, not the sum of their individual volumes. So the problem reduces to "What is the volume of the merged cuboids?". The constraints are n, m \leq 100000, which means we need a solution faster than O(n^2).

Imagine the problem in a 3-dimensional Cartesian system: skill A is the X-axis, B is the Y-axis and C is the Z-axis.
It is easy to see that if a cuboid spans up to a on the X-axis and b on the Y-axis at height c on the Z-axis, then it also spans up to a (X-axis) and b (Y-axis) at every height d \leq c. So a cuboid with a higher c also affects every layer below it.

For each layer c on the Z-axis, we compute the area in the X–Y plane covered by any cuboid. Summing these areas over all layers gives the volume of the merged cuboids.
Now the problem is abstracted again: for every Z-axis value 1 \leq c \leq m, we need the area covered by any cuboid.

Assume two cuboids both reach some Z-axis value. Let the first one span (a, b) and the second one span (c, d), in terms of (X-axis, Y-axis). If b \leq d, the second cuboid covers the first one if a \leq c, and does not cover it otherwise.

Let us create a one-dimensional array of length m indexed by the X-axis. Each element stores how far along the Y-axis that X position is covered. The roles of X and Y can also be swapped.
Let us use the first test case as an example:

3 10
2 8 5
6 3 5
1 3 9

There are 3 cuboids. Sorting them by Z-axis value in descending order gives the following (the order of the second and third cuboids doesn't matter):

1 3 9
2 8 5
6 3 5

We start with c=10. Since no cuboid reaches this Z-axis value, all values in the array are 0.

0 0 0 0 0 0 0 0 0 0

Next we reach c=9, where we meet the first cuboid. It updates the array to:

3 0 0 0 0 0 0 0 0 0

The covered area at this layer is 3, and it stays 3 for every layer down to c=6.
When we get to c=5, the second cuboid changes the array to:

8 8 0 0 0 0 0 0 0 0

The third cuboid updates the array to:

8 8 3 3 3 3 0 0 0 0

Note that the first and second values do not change, because 8 > 3. At X positions 1 and 2 there is already a cuboid reaching Y = 8, which covers any cuboid that reaches less than that.
The covered area is now 8 + 8 + 3 \cdot 4 = 28, and it stays 28 for every layer down to c=1. The total volume of the union is 3 \cdot 4 + 28 \cdot 5 = 152, so the answer is 10^3 - 152 = 848, as expected.

The problem is now easy: we need a data structure that supports the following 2 operations:
1. Range update: for every position in a given range (here always a prefix [1, a]), replace the value with v, but only where the previous value is less than v.
2. Query the sum of all values.

This can be solved with a segment tree, but I am not going to explain that solution, because I haven't figured it out (though Zaky has). My solution is essentially the same as my Subsequence Weighting solution: using map<int, int>, you can do the update and the sum in O(\log m).

Here is my code: http://ideone.com/NUFn6z

Let me know if you need more explanation in the solution. 😀

Subsequence Weighting

This is my first post on this blog. To get things started, I will share an interesting problem I found a while back on HackerRank.

The solution is fairly obvious to anyone familiar with segment trees or BITs (Fenwick trees). The easiest implementation uses a BIT and takes O(n \log n) time here, which is sufficient since n \leq 150000.

The solution I will explain is not the former one; it uses only std::map from C++. When I first read the problem, I had not yet learned those advanced data structures, so I thought about how to model the problem as a stack of increasing values.

There are a few prunings we can apply if each stack element stores a (value, weight) pair and the stack is ordered by increasing value. To achieve the optimum, each element's weight must be larger than the weight of the element below it (the one with a smaller value). Otherwise the lower element is never worse, and the upper one can be discarded. If this property is maintained on every insertion, the weight at the top of the stack is guaranteed to be the maximum. The problem is that an element sometimes has to be inserted in the middle of the stack.

An ordered map such as std::map provides exactly this. Looking up the previous element (lower value) takes O(\log n). If an element above the current one (larger value) has less weight than the current element, it must be removed to maintain the property; each deletion takes O(\log n). Every element is inserted once and deleted at most once, so the total complexity is O(n \log n).

#include <cstdio>
#include <stack>
#include <queue>
#include <vector>
#include <map>
#include <utility>
#include <string>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <fstream>
#include <sstream>
#include <set>
using namespace std;
// template

// abbreviations
// NOTE: `a` and `b` are object-like macros, so every later identifier spelled
// `a` or `b` is rewritten to `first` / `second`. Do not use them as names.
#define vi vector <int>
#define ii pair <int, int>
#define a first
#define b second
#define vii vector <ii>
#define mii map <int, int>
#define mll map <long long, long long>
#define que queue
#define pque priority_queue
#define stk stack
// Lowest set bit of value (the Fenwick-tree step). Fully parenthesized so that
// lsone(x + 1) and comparisons such as lsone(x) == 4 expand correctly.
#define lsone(value) ((value) & (-(value)))

typedef unsigned long long ull;
typedef long long ll;

int main() {
    // freopen("subsequence-weighting.in", "r", stdin);
    int nt;
    scanf("%d", &nt);
    while (nt--) {
        int n;
        scanf("%d", &n);
        vector<pair<ull, ull> > seqs(n);
        for (int it = 0; it < n; ++it)
            scanf("%llu", &seqs[it].first);
        for (int it = 0; it < n; ++it)
            scanf("%llu", &seqs[it].second);


        map<ull, ull> maxWeight;

        vector<pair<ull, ull> >::iterator curr = seqs.end();
        for (advance(curr, -1); curr >= seqs.begin(); --curr) {
            ull key = curr->first;
            ull val = curr->second;
            
            map<ull, ull>::iterator higher = maxWeight.upper_bound(key);
            if (higher != maxWeight.end()) // there is no higher key
                val += higher->second;

            if ((maxWeight.count(key)) and (maxWeight[key] >= val))
                continue;

            maxWeight[key] = val;

            map<ull, ull>::iterator lower;
            while ((lower = maxWeight.lower_bound(key)) != maxWeight.begin()) {
                advance(lower, -1);
                if (lower->second < val)
                    maxWeight.erase(lower);
                else
                    break;
            }
        }
        printf("%llu\n", maxWeight.begin()->second);
    }
}