# Generating an optimal binary search tree (Cormen)

I'm reading Cormen et al., Introduction to Algorithms (3rd ed.), section 15.4 on optimal binary search trees, but am having some trouble implementing the pseudocode for the OPTIMAL-BST procedure in Python.

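Here is the example I'm trying to apply the optimal BST to, where p_i is the probability of searching for key k_i and q_i is the probability of an unsuccessful search ending at dummy key d_i:

| i    | 0    | 1    | 2    | 3    | 4    | 5    |
|------|------|------|------|------|------|------|
| p_i  |      | 0.15 | 0.10 | 0.05 | 0.10 | 0.20 |
| q_i  | 0.05 | 0.10 | 0.05 | 0.05 | 0.05 | 0.10 |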

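Let us define e[i, j] as the expected cost of searching an optimal binary search tree containing the keys k_i, ..., k_j. Ultimately, we wish to compute e[1, n], where n is the number of keys (5 in this example). Writing $w(i, j) = \sum_{l=i}^{j} p_l + \sum_{l=i-1}^{j} q_l$ for the total probability of such a subtree, the final recursive formulation is:

$$
e[i, j] =
\begin{cases}
q_{i-1} & \text{if } j = i - 1,\\
\min\limits_{i \le r \le j} \{\, e[i, r-1] + e[r+1, j] + w(i, j) \,\} & \text{if } i \le j,
\end{cases}
$$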

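which the book computes bottom-up with the OPTIMAL-BST procedure. It fills tables e[1..n+1, 0..n], w[1..n+1, 0..n], and root[1..n, 1..n]: first e[i, i-1] = w[i, i-1] = q_{i-1} for i = 1, ..., n+1; then, for each subtree length l = 1, ..., n and each i = 1, ..., n-l+1 with j = i+l-1, it sets w[i, j] = w[i, j-1] + p_j + q_j and stores in e[i, j] and root[i, j] the smallest value of e[i, r-1] + e[r+1, j] + w[i, j] over r = i, ..., j and the r that achieves it.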
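As a table-free cross-check of the recursive formulation, a direct top-down translation can be memoized with `functools.lru_cache`. This is a sketch, not the book's procedure: the names `weight` and `expected_cost` are hypothetical, and padding `p` with a dummy `p[0]` is an assumption made so the code can use the book's 1-based p_i next to the 0-based q_i.

```python
from functools import lru_cache

# Assumption: p is padded with an unused p[0] so that p[i] is the book's p_i;
# q[i] is already the book's q_i (indexed from 0).
p = [None, 0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p) - 1


def weight(i, j):
    """w(i, j) = p_i + ... + p_j + q_{i-1} + ... + q_j (called with i >= 1)."""
    return sum(p[i:j + 1]) + sum(q[i - 1:j + 1])


@lru_cache(maxsize=None)
def expected_cost(i, j):
    """e[i, j] straight from the recurrence, for 1 <= i <= n+1, i-1 <= j <= n."""
    if j == i - 1:
        return q[i - 1]  # empty subtree: only the dummy key d_{i-1}
    w = weight(i, j)
    # Try every key k_r as the root of the subtree holding k_i..k_j.
    return min(expected_cost(i, r - 1) + expected_cost(r + 1, j) + w
               for r in range(i, j + 1))


print(round(expected_cost(1, n), 2))  # 2.75
```

Memoization gives the same O(n^3) running time as the bottom-up tables, and the cached values are handy for checking a table-based version entry by entry.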

Notice that the pseudocode mixes 1- and 0-based indexing (p is indexed from 1 and q from 0, and the rows of e and w start at 1 while their columns start at 0), whereas Python indexes everything from 0. As a consequence, I'm having trouble translating it. Here is what I have so far:

```python
import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n+1):
    for i in range(n-l+1):
        j = i + l
        e[i, j] = np.inf
        w[i, j] = w[i, j-1] + p[j-1] + q[j]
        for r in range(i, j+1):
            t = e[i-1, r-1] + e[r, j] + w[i-1, j]
            if t < e[i-1, j]:
                e[i-1, j] = t
                root[i-1, j] = r

print(w)
print(e)
```

However, when I run this, the weights w are computed correctly, but the expected costs e stay stuck at inf everywhere above the diagonal:

```
[[ 0.05  0.3   0.45  0.55  0.7   1.  ]
 [ 0.    0.1   0.25  0.35  0.5   0.8 ]
 [ 0.    0.    0.05  0.15  0.3   0.6 ]
 [ 0.    0.    0.    0.05  0.2   0.5 ]
 [ 0.    0.    0.    0.    0.05  0.35]
 [ 0.    0.    0.    0.    0.    0.1 ]]
[[ 0.05   inf   inf   inf   inf   inf]
 [ 0.    0.1    inf   inf   inf   inf]
 [ 0.    0.    0.05   inf   inf   inf]
 [ 0.    0.    0.    0.05   inf   inf]
 [ 0.    0.    0.    0.    0.05   inf]
 [ 0.    0.    0.    0.    0.    0.1 ]]
```

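What I expect is that e, w, and root match the tables the book gives for this example (they also appear in the pandas output further down); in particular, e[1, 5] = 2.75 and w[1, 5] = 1.00.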

I've been debugging this for a couple of hours now and am still stuck. Can someone point out what is wrong with the Python code above?
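One reading of the loop in the question: w uses a consistent mapping (NumPy row i holds the book's row i + 1, columns keep the book's numbering 0..n), but the e update does not. It sets e[i, j] = inf and then compares against and writes to e[i-1, j], a different cell that still holds its starting value from np.diag (for i = 0, e[-1, j] even wraps around to the last row), and t never undercuts that value, so nothing is ever stored. Below is a minimal sketch that applies the same row shift everywhere; the shift of r to i+1..j and storing root 0-based in both dimensions are my assumptions about the intended layout, not the book's code.

```python
import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

# Layout assumption: NumPy row i holds the book's row i + 1, while columns
# keep the book's numbering 0..n.
e = np.diag(q)           # e[i, i] = q[i] is the book's e[i+1, i] = q_i
w = np.diag(q)
root = np.zeros((n, n))  # root[i, j-1] holds the book's root[i+1, j]
for l in range(1, n + 1):
    for i in range(n - l + 1):
        j = i + l
        e[i, j] = np.inf
        w[i, j] = w[i, j - 1] + p[j - 1] + q[j]
        for r in range(i + 1, j + 1):  # r is the book's key number, i+1..j
            # book: e[i+1, r-1] + e[r+1, j] + w[i+1, j]
            t = e[i, r - 1] + e[r, j] + w[i, j]
            if t < e[i, j]:
                e[i, j] = t
                root[i, j - 1] = r

# Compare with the book's e[1, 5] = 2.75 and w[1, 5] = 1.00.
print(e[0, n], w[0, n])
```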

It appears to me that you made a mistake in the indices. I couldn't make it work as expected, but the following code should show where I was heading (there is probably an off-by-one error somewhere):

```python
import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

def get2(m, i, j):
    return m[i - 1, j - 1]

def set2(m, i, j, v):
    m[i - 1, j - 1] = v

def get1(m, i):
    return m[i - 1]

def set1(m, i, v):
    m[i - 1] = v

e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n + 1):
    for i in range(n - l + 2):
        j = i + l - 1
        set2(e, i, j, np.inf)
        set2(w, i, j, get2(w, i, j - 1) + get1(p, j) + get1(q, j))
        for r in range(i, j + 1):
            t = get2(e, i, r - 1) + get2(e, r + 1, j) + get2(w, i, j)
            if t < get2(e, i, j):
                set2(e, i, j, t)
                set2(root, i, j, r)

print(w)
print(e)
```

The result:

```
[[ 0.2   0.4   0.5   0.65  0.9   0.  ]
 [ 0.    0.2   0.3   0.45  0.7   0.  ]
 [ 0.    0.    0.1   0.25  0.5   0.  ]
 [ 0.    0.    0.    0.15  0.4   0.  ]
 [ 0.    0.    0.    0.    0.25  0.  ]
 [ 0.5   0.7   0.8   0.95  0.    0.3 ]]
[[ 0.2   0.6   0.8   1.2   1.95  0.  ]
 [ 0.    0.2   0.4   0.8   1.35  0.  ]
 [ 0.    0.    0.1   0.35  0.85  0.  ]
 [ 0.    0.    0.    0.15  0.55  0.  ]
 [ 0.    0.    0.    0.    0.25  0.  ]
 [ 0.7   1.2   1.5   2.    0.    0.3 ]]
```
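As far as I can tell, the accessor idea works once the helpers match the book's actual index ranges: e and w have rows 1..n+1 but columns 0..n, so only the row should be shifted; q is 0-based in the book and should not be shifted at all; and i has to start at 1, not 0 (with i = 0 the helpers reach index -1, which wraps around and likely explains the stray values in the last row and column above). The helper names `get_ew`, `set_ew`, and `set_root` are hypothetical replacements for `get2`/`set2`.

```python
import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

# Hypothetical helpers for the book's tables e[1..n+1, 0..n] and
# w[1..n+1, 0..n]: only the row index is shifted.
def get_ew(m, i, j):
    return m[i - 1, j]

def set_ew(m, i, j, v):
    m[i - 1, j] = v

# The book's root[1..n, 1..n] is 1-based in both dimensions.
def set_root(m, i, j, v):
    m[i - 1, j - 1] = v

e = np.diag(q)  # the book's e[i, i-1] = q_{i-1} lands on the diagonal
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n + 1):
    for i in range(1, n - l + 2):
        j = i + l - 1
        set_ew(e, i, j, np.inf)
        # p is 1-based in the book (hence p[j - 1]); q is 0-based (q[j] as is)
        set_ew(w, i, j, get_ew(w, i, j - 1) + p[j - 1] + q[j])
        for r in range(i, j + 1):
            t = get_ew(e, i, r - 1) + get_ew(e, r + 1, j) + get_ew(w, i, j)
            if t < get_ew(e, i, j):
                set_ew(e, i, j, t)
                set_root(root, i, j, r)

print(w)
print(e)
print(root)
```

Printed this way, w and e should match the expected tables, with NumPy row 0 corresponding to the book's row 1.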

In the end I used pandas' Series and DataFrame objects initialized with custom index and columns to coerce the arrays to have the same indexing as in the pseudocode. After that, the pseudocode can be almost copy-pasted:

```python
import numpy as np
import pandas as pd

P = [0.15, 0.10, 0.05, 0.10, 0.20]
Q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(P)

# Label-based indices as in the pseudocode: p[1..n], q[0..n],
# e and w with rows 1..n+1 and columns 0..n, root[1..n, 1..n].
p = pd.Series(P, index=range(1, n+1))
q = pd.Series(Q)

e = pd.DataFrame(np.diag(Q), index=range(1, n+2))
w = pd.DataFrame(np.diag(Q), index=range(1, n+2))
root = pd.DataFrame(np.zeros((n, n)), index=range(1, n+1), columns=range(1, n+1))

# .at replaces get_value/set_value, which were removed in pandas 1.0.
for l in range(1, n+1):
    for i in range(1, n-l+2):
        j = i+l-1
        e.at[i, j] = np.inf
        w.at[i, j] = w.at[i, j-1] + p[j] + q[j]
        for r in range(i, j+1):
            t = e.at[i, r-1] + e.at[r+1, j] + w.at[i, j]
            if t < e.at[i, j]:
                e.at[i, j] = t
                root.at[i, j] = r

print(e)
print(w)
print(root)
```

which yields the expected results (e, w, and root, in that order):

```
      0     1     2     3     4     5
1  0.05  0.45  0.90  1.25  1.75  2.75
2  0.00  0.10  0.40  0.70  1.20  2.00
3  0.00  0.00  0.05  0.25  0.60  1.30
4  0.00  0.00  0.00  0.05  0.30  0.90
5  0.00  0.00  0.00  0.00  0.05  0.50
6  0.00  0.00  0.00  0.00  0.00  0.10
      0    1     2     3     4     5
1  0.05  0.3  0.45  0.55  0.70  1.00
2  0.00  0.1  0.25  0.35  0.50  0.80
3  0.00  0.0  0.05  0.15  0.30  0.60
4  0.00  0.0  0.00  0.05  0.20  0.50
5  0.00  0.0  0.00  0.00  0.05  0.35
6  0.00  0.0  0.00  0.00  0.00  0.10
     1    2    3    4    5
1  1.0  1.0  2.0  2.0  2.0
2  0.0  2.0  2.0  2.0  4.0
3  0.0  0.0  3.0  4.0  5.0
4  0.0  0.0  0.0  4.0  5.0
5  0.0  0.0  0.0  0.0  5.0
```

I would still be interested in a solution with Numpy arrays, though, as this seems more elegant to me.
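A NumPy-only variant that keeps the book's indices verbatim, sketched under the assumption that wasting one row and one column is acceptable: allocate e and w with shape (n + 2, n + 1) and root with shape (n + 1, n + 1), and pad p with a dummy entry at index 0. Row 0 of every table and column 0 of root are simply never used, so the loop body should read as a line-by-line copy of OPTIMAL-BST. The function name `optimal_bst` is hypothetical, and it also returns w (the book's procedure returns only e and root) so the tables can be compared.

```python
import numpy as np


def optimal_bst(p, q, n):
    """Bottom-up optimal BST tables using the book's indices directly.

    p[1..n] are key probabilities (p[0] is an unused dummy), q[0..n] are
    dummy-key probabilities. Row 0 of e and w, and row/column 0 of root,
    are padding and stay unused.
    """
    e = np.zeros((n + 2, n + 1))  # e[1..n+1, 0..n]
    w = np.zeros((n + 2, n + 1))  # w[1..n+1, 0..n]
    root = np.zeros((n + 1, n + 1), dtype=int)  # root[1..n, 1..n]
    for i in range(1, n + 2):
        e[i, i - 1] = q[i - 1]
        w[i, i - 1] = q[i - 1]
    for l in range(1, n + 1):
        for i in range(1, n - l + 2):
            j = i + l - 1
            e[i, j] = np.inf
            w[i, j] = w[i, j - 1] + p[j] + q[j]
            for r in range(i, j + 1):
                t = e[i, r - 1] + e[r + 1, j] + w[i, j]
                if t < e[i, j]:
                    e[i, j] = t
                    root[i, j] = r
    return e, w, root


p = [0.0, 0.15, 0.10, 0.05, 0.10, 0.20]  # p[0] is padding
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
e, w, root = optimal_bst(p, q, 5)
print(e[1:, :])      # rows 1..n+1, columns 0..n
print(w[1:, :])
print(root[1:, 1:])  # rows and columns 1..n
```

Slicing off the padding row and column when printing gives the same layout as the pandas output.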