longest increasing subsequence

tl;dr

This post covers the $O(n^2)$ and $O(n \log n)$ methods for finding the length of the longest increasing subsequence (LIS), and how to recover the subsequence itself.

the naive dynamic programming

… which has $O(n^2)$ time complexity. dp[i] is the length of the longest increasing subsequence ending at nums[i]. Thus dp[i] = max(1, dp[j] + 1) over all 0 <= j < i with nums[j] < nums[i]; it stays 1 when no such j exists.

import java.util.*;
public class Solution {
    public int lengthOfLIS(int[] nums) {
        // time O(n^2)
        // space O(n)
        
        int[] dp = new int[nums.length];
        // dp[i] is the max length of LIS ending with nums[i]
        Arrays.fill(dp, 1); // attention here! every element alone is a subsequence of length 1
        int res = 0;
        
        for (int i = 0; i < nums.length; ++i) {
            for (int j = 0; j < i; ++j) {
                if (nums[j] < nums[i]) {
                    dp[i] = Math.max(dp[i], dp[j] + 1);
                }
            }
            res = Math.max(res, dp[i]);
        }
        return res;
    }
}
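
The quadratic DP can also recover the subsequence itself by remembering, for each i, which j produced dp[i]. The following is a minimal sketch of that idea, not part of the LeetCode solution above; the class name QuadraticLis, the method name and the prev array are names chosen here for illustration.

```java
import java.util.Arrays;

class QuadraticLis {
    // Returns one longest strictly increasing subsequence of nums.
    static int[] longestIncreasingSubsequence(int[] nums) {
        int n = nums.length;
        if (n == 0) {
            return new int[0];
        }
        int[] dp = new int[n];   // dp[i]: length of the LIS ending at nums[i]
        int[] prev = new int[n]; // prev[i]: index of the element before nums[i], or -1
        Arrays.fill(dp, 1);
        Arrays.fill(prev, -1);
        int best = 0;            // index at which the overall LIS ends

        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < i; ++j) {
                if (nums[j] < nums[i] && dp[j] + 1 > dp[i]) {
                    dp[i] = dp[j] + 1;
                    prev[i] = j;
                }
            }
            if (dp[i] > dp[best]) {
                best = i;
            }
        }

        // Walk the prev links backwards from the best ending index.
        int[] lis = new int[dp[best]];
        for (int i = best, k = lis.length - 1; i >= 0; i = prev[i], --k) {
            lis[k] = nums[i];
        }
        return lis;
    }

    public static void main(String[] args) {
        int[] nums = {0, 2, 6, 1, 7, 4, 8, 5, 9, 7};
        System.out.println(Arrays.toString(longestIncreasingSubsequence(nums)));
    }
}
```

When several subsequences share the maximum length, this sketch returns the one whose last element appears first in nums.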

the optimized dynamic programming

… which has $O(n \log n)$ time complexity because each of the n numbers costs one binary search. dp[i] is the smallest possible ending value among all increasing subsequences of length i + 1. For each number in nums, we binary-search dp for the first entry that is greater than or equal to the number and overwrite it, or append the number if there is no such entry. This keeps dp strictly increasing. GeeksforGeeks has one of the best explanations.

import java.util.*;
public class Solution {
    public int lengthOfLIS(int[] nums) {
        // time O(n log(n))
        // space O(n)
        int[] dp = new int[nums.length];
        // dp[i] is the smallest possible ending value among all increasing subsequences of length i + 1
        int res = 0;
        
        for (int i = 0; i < nums.length; ++i) {
            int num = nums[i];
            int pos = Arrays.binarySearch(dp, 0, res, num); // (array, start, end, key) (end exclusive)
            if (pos < 0) {
                // not found, -> insertion point
                pos = - (pos + 1);
            }
            dp[pos] = num;
            
            if (pos == res) {
                ++res;
            }
        }
        return res;
    }
}
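
To see how dp evolves, it can help to print its live prefix (the first res entries) after every step. The sketch below repeats the loop above and collects those snapshots; DpTrace and trace are hypothetical names used only for this illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class DpTrace {
    // Returns the first res entries of dp after each element of nums is processed.
    static List<String> trace(int[] nums) {
        int[] dp = new int[nums.length];
        int res = 0;
        List<String> snapshots = new ArrayList<>();
        for (int num : nums) {
            int pos = Arrays.binarySearch(dp, 0, res, num);
            if (pos < 0) {
                pos = -(pos + 1); // not found: convert to the insertion point
            }
            dp[pos] = num;
            if (pos == res) {
                ++res;
            }
            snapshots.add(Arrays.toString(Arrays.copyOf(dp, res)));
        }
        return snapshots;
    }

    public static void main(String[] args) {
        for (String s : trace(new int[]{0, 2, 6, 1, 7, 4, 8, 5, 9, 7})) {
            System.out.println(s);
        }
    }
}
```

For this input, the fourth snapshot is [0, 1, 6]: the 1 overwrote the 2, because a length-2 subsequence ending in 1 is easier to extend than one ending in 2.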

how to reconstruct the subsequence?

Many posts present the optimized method above but do NOT clearly explain how to reconstruct the subsequence. For example, one SO post discusses the reconstruction, but doesn’t provide the code!

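Note that the final dp is NOT the LIS. For example, with nums = [0, 2, 6, 1, 7, 4, 8, 5, 9, 7], the first res entries of dp end up as [0, 1, 4, 5, 7, 9], which is not a subsequence of nums: the only 7 that comes after the 5 is the last element, so no 9 can follow it.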

So one way is to maintain an array prev, where prev[i] is the index in nums of the element that precedes nums[i] in its subsequence (its “parent”). For simplicity, we also add an array dpIdx that records, for each value in dp, its original index in nums (we could actually replace dp with dpIdx to save space, but then we would need to write our own binary search). In each iteration, we use binary search to find the insertion point pos in dp. If pos > 0, the number extends the subsequence ending with dp[pos - 1], thus prev[i] = dpIdx[pos - 1]. Otherwise pos == 0, the number starts a new subsequence, and we set prev[i] = -1 to indicate that it has no previous element. Once we know the length, we construct the LIS backwards, starting from dpIdx[res - 1] and following prev.

import java.util.*;
public class Solution {
    public int lengthOfLIS(int[] nums) {
        // time O(n log(n))
        // space O(n)
        
        int[] dp = new int[nums.length];
        int[] dpIdx = new int[nums.length];
        int[] prev = new int[nums.length];
        int res = 0;
        
        for (int i = 0; i < nums.length; ++i) {
            int num = nums[i];
            int pos = Arrays.binarySearch(dp, 0, res, num); // (array, start, end, key) (end exclusive)
            if (pos < 0) {
                // not found, -> insertion point
                pos = - (pos + 1);
            }
            
            dp[pos] = num;
            dpIdx[pos] = i;
            if (pos > 0) {
                prev[i] = dpIdx[pos - 1];
            } else {
                prev[i] = -1;
            }
            
            if (pos == res) {
                ++res;
            }
        }

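        // reconstruct the LIS by following prev backwards from the last tail
        // (start at -1 when nums is empty, so dpIdx[-1] is never read)
        int[] lis = new int[res];
        for (int i = res > 0 ? dpIdx[res - 1] : -1, j = lis.length - 1; i >= 0 && j >= 0; i = prev[i], j -= 1) {
            lis[j] = nums[i];
        }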
        
        System.out.printf("nums:  %s\n", Arrays.toString(nums));
        System.out.printf("dp:    %s\n", Arrays.toString(dp));
        System.out.printf("dpIdx: %s\n", Arrays.toString(dpIdx));
        System.out.printf("prev:  %s\n", Arrays.toString(prev));
        System.out.printf("LIS:   %s\n", Arrays.toString(lis));
        
        return res;
    }

    public static void main(String[] args) {
        Solution s = new Solution();
        int res = s.lengthOfLIS(new int[]{0, 2, 6, 1, 7, 4, 8, 5, 9, 7});
    }
}
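
The space-saving variant mentioned above, where dp is dropped and only indices are kept, could look roughly like this. It is a sketch under the assumption that a hand-written lower-bound search replaces Arrays.binarySearch; IndexOnlyLis and tailIdx are illustrative names, not from any library.

```java
import java.util.Arrays;

class IndexOnlyLis {
    // Returns one longest strictly increasing subsequence of nums.
    static int[] longestIncreasingSubsequence(int[] nums) {
        int n = nums.length;
        int[] tailIdx = new int[n]; // tailIdx[k]: index in nums of the smallest tail for length k + 1
        int[] prev = new int[n];    // prev[i]: index of the element before nums[i], or -1
        int len = 0;

        for (int i = 0; i < n; ++i) {
            // Lower bound: first k in [0, len) with nums[tailIdx[k]] >= nums[i].
            int lo = 0, hi = len;
            while (lo < hi) {
                int mid = (lo + hi) >>> 1;
                if (nums[tailIdx[mid]] < nums[i]) {
                    lo = mid + 1;
                } else {
                    hi = mid;
                }
            }
            prev[i] = lo > 0 ? tailIdx[lo - 1] : -1;
            tailIdx[lo] = i;
            if (lo == len) {
                ++len;
            }
        }

        // Follow prev backwards from the tail of the longest subsequence.
        int[] lis = new int[len];
        for (int i = len > 0 ? tailIdx[len - 1] : -1, k = len - 1; i >= 0; i = prev[i], --k) {
            lis[k] = nums[i];
        }
        return lis;
    }

    public static void main(String[] args) {
        int[] nums = {0, 2, 6, 1, 7, 4, 8, 5, 9, 7};
        System.out.println(Arrays.toString(longestIncreasingSubsequence(nums)));
    }
}
```

Returning the array instead of printing it makes the result easy to check in a test, and the empty input needs no special case beyond the guard in the last loop.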

how about longest non-decreasing subsequence?

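For a strictly increasing subsequence we want the lower bound: the first position whose value is greater than or equal to the target. If we want the subsequence to be non-decreasing, we should instead insert num at the first position whose value is strictly larger than the target! Actually std::upper_bound in C++ is what I mean. Note that dp may now contain duplicates, and Arrays.binarySearch gives no guarantee about which index it returns when the key occurs more than once.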

For my above code in Java, pay attention to:

            int pos = Arrays.binarySearch(dp, 0, res, num); // (array, start, end, key) (end exclusive)
            if (pos < 0) {
                // not found, -> insertion point
                pos = - (pos + 1);
            }

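Change it to the following (the extra loop is linear in the number of duplicates, so an input with many equal values can degrade the whole algorithm to $O(n^2)$):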

            int pos = Arrays.binarySearch(dp, 0, res, num); // (array, start, end, key) (end exclusive)
            if (pos < 0) {
                // not found, -> insertion point
                pos = - (pos + 1);
            }
            while (pos < res && dp[pos] == num) {
                // shift right
                pos += 1;
            }

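Or, if we are only dealing with integers, we could search for num + 1 instead. Because dp may contain duplicates, Arrays.binarySearch can land on any of several entries equal to num + 1, so we shift left to the first one. This also assumes that num + 1 does not overflow, i.e. num < Integer.MAX_VALUE:

            // pay attention: search num + 1, not num (assumes num < Integer.MAX_VALUE)
            int pos = Arrays.binarySearch(dp, 0, res, num + 1); // (array, start, end, key) (end exclusive)
            if (pos < 0) {
                // not found, -> insertion point
                pos = - (pos + 1);
            }
            while (pos > 0 && dp[pos - 1] == num + 1) {
                // found among duplicates: shift left to the first entry larger than num
                pos -= 1;
            }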
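
Another option is to write the upper-bound search by hand, which keeps every step at $O(\log n)$ and sidesteps both the question of which duplicate Arrays.binarySearch returns and the num + 1 overflow. A minimal sketch, with NonDecreasingLis, upperBound and lengthOfLongestNonDecreasing as names made up for this example:

```java
class NonDecreasingLis {
    // First index in dp[0, len) whose value is strictly greater than key, or len if none.
    static int upperBound(int[] dp, int len, int key) {
        int lo = 0, hi = len;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (dp[mid] <= key) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    static int lengthOfLongestNonDecreasing(int[] nums) {
        int[] dp = new int[nums.length]; // dp[0, res) is non-decreasing and may hold duplicates
        int res = 0;
        for (int num : nums) {
            int pos = upperBound(dp, res, num);
            dp[pos] = num;
            if (pos == res) {
                ++res;
            }
        }
        return res;
    }

    public static void main(String[] args) {
        System.out.println(lengthOfLongestNonDecreasing(new int[]{2, 2, 2, 1, 1, 1, 1}));
    }
}
```

The input {2, 2, 2, 1, 1, 1, 1} is a useful check: the longest non-decreasing subsequence is the four 1s, and dp holds three equal values when the first 1 arrives.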

try it

On LeetCode #300.