In [1]:
import numpy as np
import time
import seaborn as sns
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import requests
from bs4 import BeautifulSoup
from skimage.morphology import label
from scipy.ndimage import measurements
from scipy.signal import convolve2d
from z3 import *
from IPython.display import Markdown, display,Image
from matplotlib.patches import Rectangle
In [2]:
# Fetch and render the puzzle statement from the Jane Street site.
url = 'https://www.janestreet.com/puzzles/current-puzzle/'
# NOTE(review): "current-puzzle" changes over time, while the image below is pinned to
# hooks-10.png and the slices y[13] / y[14:22] look tied to the page layout when this
# was written -- verify they still point at the title and the rules before trusting them.
res = requests.get(url, timeout=30)
res.raise_for_status()  # fail loudly rather than parsing an error page
soup = BeautifulSoup(res.content, 'html.parser')
# Visible text fragments of the page; ">" is spelled out so Markdown does not treat it as a quote.
y = [text.replace(">", "greater than") for text in soup.body.stripped_strings]
display(Markdown("# " + y[13] + "\n ---- \n"))
display(Image('https://www.janestreet.com/puzzles/hooks-10.png', width=600))
display(Markdown("\n".join(y[14:22])))

Hooks 10


The grid above can be partitioned into 9 L-shaped “hooks”. The largest is 9-by-9 (contains 17 squares), the next largest is 8-by-8 (contains 15 squares), and so on. The smallest hook is just a single square. Find where the hooks are located, and place nine 9’s in one of the hooks, eight 8’s in another, seven 7’s in another, and so on. The filled squares must form a connected region. (Squares are “connected” if they are orthogonally adjacent.) Furthermore, every 2-by-2 region must contain at least one unfilled square. The clues in the grid are placed in cells that are not filled in the completed grid. A number in the grid represents the sum of all values in orthogonally adjacent cells in the completed grid. The answer to this puzzle is the product of the areas of the connected groups of empty squares in the completed grid.

In [3]:
# Clue cells, keyed by (row, col) -> required sum of the filled orthogonal neighbours.
# cons_small's indices only reach 4 (presumably a 5x5 example grid -- TODO confirm);
# cons_big's indices reach 8, i.e. a 9x9 grid.
cons_small = {
    (0, 0): 0, (2, 0): 8, (4, 0): 10,
    (1, 2): 9, (3, 2): 15,
    (1, 4): 7, (3, 4): 12,
}
cons_big = {
    (0, 1): 18, (0, 6): 7,
    (1, 4): 12,
    (2, 2): 9, (2, 7): 31,
    (4, 1): 5, (4, 3): 11, (4, 5): 22, (4, 7): 22,
    (6, 1): 9, (6, 6): 19,
    (7, 4): 14,
    (8, 2): 22, (8, 7): 15,
}
In [4]:
def grid_print(solved,matrix, cons):
    """Draw the solved grid as an annotated heatmap.

    Parameters
    ----------
    solved : 2-D array, non-zero where a cell is filled.
    matrix : 2-D integer array used as the heatmap colour data; ``solved * matrix``
        supplies the annotation shown in each cell (the caller presumably passes
        the per-cell digits -- confirm at the call site).
    cons : dict mapping (row, col) -> clue value. Those cells are annotated with
        the clue written as a negative number, which sets them apart from placed digits.

    Empty cells are annotated with "-".
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    labels = np.array((solved * matrix).astype('int').astype('str'))
    labels[labels == "0"] = "-"
    for (i, j), v in cons.items():
        labels[i, j] = -v
    # Draw on the axes created above instead of relying on pyplot's implicit current axes.
    sns.heatmap(matrix, annot=labels, cbar=False, cmap="tab10", fmt="", linewidths=0.25,
                annot_kws={"fontsize": 12}, ax=ax)

    fig.tight_layout()
    plt.show()
    
    
def neigh(x,y,N):
    """Orthogonal neighbours of cell (x, y) that lie inside an N x N grid.

    Neighbours are listed in the order (x-1, y), (x, y-1), (x+1, y), (x, y+1);
    any that would fall outside 0..N-1 in either coordinate are dropped.
    """
    candidates = ((x - 1, y), (x, y - 1), (x + 1, y), (x, y + 1))
    return [(r, c) for r, c in candidates if 0 <= r < N and 0 <= c < N]

def areas(grid):
    """Return the puzzle answer: the product of the areas of the groups of empty cells.

    Parameters
    ----------
    grid : 2-D array where 0 marks an empty (unfilled) cell and any other value a filled one.

    Empty cells are grouped by orthogonal adjacency. The group sizes are printed,
    then their product is returned as a float. A grid with no empty cells yields
    1.0 (the empty product).
    """
    # scipy.ndimage is the non-deprecated home of label / sum_labels
    # (scipy.ndimage.measurements is slated for removal in SciPy 2.0).
    from scipy import ndimage

    empty = np.asarray(grid) == 0
    labels, num = ndimage.label(empty)  # default structure = orthogonal connectivity
    if num == 0:
        region_sizes = np.zeros(0)
    else:
        region_sizes = np.asarray(
            ndimage.sum_labels(empty, labels, index=np.arange(1, num + 1)))
    print(region_sizes)
    return np.prod(region_sizes)

    
def test_solution(result_x, result_y, cons):
    """Check a candidate grid against every clue.

    Parameters
    ----------
    result_x : 2-D array of the digit in each cell.
    result_y : 2-D 0/1 array, 1 where the cell is filled.
    cons : dict mapping (row, col) -> clue value.

    Returns
    -------
    bool
        True iff every clue cell is unfilled and its clue equals the sum of
        ``result_x * result_y`` over the clue cell's orthogonal neighbours.
    """
    digits = np.asarray(result_x)
    filled = np.asarray(result_y)
    soln = digits * filled
    n = soln.shape[1]  # grid width; neigh() treats the grid as n x n
    for (i, j), v in cons.items():
        # Puzzle rule: clues are placed in cells that are not filled.
        if filled[i, j] != 0:
            return False
        # Sum the filled values of this clue's neighbours.
        if sum(soln[k, l] for k, l in neigh(i, j, n)) != v:
            return False
    return True
In [5]:
# Build the z3 model: hook layout (H), per-cell digits (X) and filled flags (Y) for an N x N grid.
start = time.time()
N = 9
cons = cons_big  # clue dict for the grid being solved; its (row, col) keys must fit inside N x N
#set up the solver (built from z3's "qffd" tactic) and the decision variables.
s = Tactic("qffd").solver()

# H = hooks: H[i,j] is the index (1..N) of the hook that cell (i,j) belongs to.
H = np.array(IntVector("h", N ** 2), dtype=object).reshape((N,N))

# row_fix[n] / col_fix[n]: the row and the column that hook n+1 is confined to (constraint below).
row_fix =IntVector("r", N)
col_fix =IntVector("c", N)

s += [And(e > 0 , e <= N) for (i,j), e in np.ndenumerate(H)] # each cell's hook index is an integer in 1...N
s += [And(e >= 0, e < N) for e in row_fix + col_fix]  # each row_fix and col_fix variable is a grid index in 0..N-1

s += Distinct(row_fix) # no two hooks share a row_fix value ...
s += Distinct(col_fix) # ... nor a col_fix value

for n in range(1, N+1): # For each of the hooks 1...N
    # hook n contains exactly 2n-1 cells
    s += PbEq([(e == n, 1) for _ , e in np.ndenumerate(H)], 2 * n - 1)
    if n != 1:
        # exactly two cells of hook n have exactly one orthogonal neighbour that is also in hook n
        # (for an L-shape these are the two arm ends)
        s+=PbEq([(If(e==int(n),Sum([H[k,l] == int(n) for k,l in neigh(i,j,N)]),0)==1,1)for (i,j),e in np.ndenumerate(H)],2)    
        
for n in range(N):
    # every cell of hook n+1 lies in row row_fix[n] or in column col_fix[n]
    s += [Implies(e == int(n+1),Or(i == row_fix[n],j == col_fix[n])) for (i,j),e in np.ndenumerate(H)]
    
# X = nums: X[i,j] is the digit carried by cell (i,j); it only counts where Y[i,j] == 1.
X = np.array(IntVector("x",N**2),dtype=object).reshape((N,N))    
hook_num = IntVector("n",N)  # hook_num[n] is the digit assigned to hook n+1

s += [And(e>0,e<N+1) for e in hook_num] # hook nums range from 1...N
s += [And(e>0,e<N+1) for (i,j),e in np.ndenumerate(X)] # every cell's digit ranges from 1...N (filled or not)
s += Distinct(hook_num) # The hook numbers must be distinct / one of each

for n in range(N):
    s += [Implies(e==n+1,X[i,j]==hook_num[n]) for (i,j),e in np.ndenumerate(H)] # X entries must match their hook number

# Y = filled: Y[i,j] is 1 for a filled cell and 0 for an empty one.
Y = np.array(IntVector("y",N**2),dtype=object).reshape((N,N)) # define Y variables
# clue cells are forced empty; every other cell is 0 or 1
s += [e == 0 if (i,j) in cons.keys() else Or(e==0,e==1) for (i,j),e in np.ndenumerate(Y)]
for n in range(1,N+1):
    # exactly n filled cells carry the digit n
    s+=PbEq([(And(e==n,Y[i,j]==1),1) for (i,j),e in np.ndenumerate(X)],n)
    
for (i,j), v in cons.items():
    # each clue equals the sum of the digits in its filled orthogonal neighbours
    s += [sum(X[a,b] * Y[a,b] for a,b in neigh(i,j,N)) == v]
###
# no filled 2x2 Constraint: every 2x2 window contains at least one empty cell
s += [Or(Y[i,j] ==0,Y[i+1,j] ==0,Y[i,j+1] ==0,Y[i+1,j+1] ==0) for j in range(N-1) for i in range(N-1)]


################### Set up connectivity constraints ################################################################
# Filled cells must form one orthogonally connected region. This is encoded as a
# directed tree over the filled cells:
#   * edge[(i,j,k,l)] is a 0/1 variable for the directed edge (i,j) -> (k,l) between neighbours;
#   * unfilled cells have no incoming or outgoing edges;
#   * every filled cell has at most one incoming edge (its parent);
#   * exactly one filled cell has no incoming edge (the root);
#   * the order label Z strictly increases along every used edge, which rules out cycles.
# Following parents from any filled cell strictly decreases Z, so it must end at the root.
edge = {}
Z = np.array(IntVector('z', N*N), dtype=object).reshape(N, N)

# create edges and variable to define order in connectivity tree
for i in range(N):
    for j in range(N):
        for (k, l) in neigh(i, j, N):
            edge[(i, j, k, l)] = Int("e%d%d%d%d" % (i, j, k, l))
            s += Or(edge[(i, j, k, l)] == 0, edge[(i, j, k, l)] == 1)

# no edges into or out of unfilled cells; at most one incoming edge per filled cell
for i in range(N):
    for j in range(N):
        incoming = Sum([edge[(k, l, i, j)] for (k, l) in neigh(i, j, N)])
        outgoing = Sum([edge[(i, j, k, l)] for (k, l) in neigh(i, j, N)])
        s += Implies(Y[i, j] == 0, incoming == 0)
        s += Implies(Y[i, j] == 0, outgoing == 0)
        s += Implies(Y[i, j] > 0, incoming <= 1)

        for (k, l) in neigh(i, j, N):
            # Make edges one way to form a tree
            s += (edge[(i, j, k, l)] + edge[(k, l, i, j)]) <= 1

# The constraints below were previously indented inside the row loop above, so they
# were added N times each (and the inner loop shadowed the row variable `i`).
# They only need to be added once.

# limit the order values
s += [And(e >= 0, e <= int(N*N)) for (i, j), e in np.ndenumerate(Z)]

# order is ascending along edges
for i in range(N):
    for j in range(N):
        s += [Implies(And(Y[i, j] != 0, edge[(k, l, i, j)] == 1), Z[i, j] > Z[k, l]) for (k, l) in neigh(i, j, N)]

# only one cell with no feed in => root
s += PbEq([(And(Y[i, j] != 0, Sum([edge[(k, l, i, j)] for (k, l) in neigh(i, j, N)]) == 0), 1)
           for i in range(N) for j in range(N)], 1)
################### End of  connectivity constraints ################################################################

                
print("setup done in {:.2f} secs".format(time.time()-start))
# Ask z3 for candidate grids until one passes test_solution. Each rejected grid's
# filled pattern is blocked so the next check() yields a different one.
stop = False
iteration = 0
while not stop:
    iteration += 1
    if iteration % 1000 == 0:
        print(f'On iteration {iteration}')
    status = s.check()
    if status != sat:
        # unsat: no candidate grid remains; unknown: the solver gave up.
        # Previously neither case was handled and the loop spun forever.
        print(f"Solver returned {status} after {iteration} iteration(s); no valid grid found")
        stop = True
    else:
        m = s.model()
        # model_completion=True guarantees a concrete value, so .as_long() cannot fail.
        evalu = np.vectorize(lambda x: m.evaluate(x, model_completion=True).as_long())
        result_x = evalu(X)  # digit in every cell
        result_y = evalu(Y)  # 1 where the cell is filled

        if test_solution(result_x, result_y, cons):
            grid_print(result_y, result_x, cons)
            print("Took {:.4f} seconds".format(time.time() - start))
            soln = result_x * result_y
            print("solution is {:,.0f}".format(areas(soln)))
            print("Connectivity: ", np.max(label(soln > 0, connectivity=1)) == 1)
            stop = True
        else:
            # Block this exact filled pattern.
            s += Or([X[i, j] * Y[i, j] != int(result_x[i, j] * result_y[i, j])
                     for (i, j), _ in np.ndenumerate(X)])
setup done in 2.44 secs
Took 20.8924 seconds
[5. 7. 1. 1. 1. 6. 1. 1. 2. 5. 2. 1. 1. 2.]
solution is 8,400
Connectivity:  True
/var/folders/8t/4cb18tcd3m52m3rjj24wt43c0000gn/T/ipykernel_48544/4101460493.py:20: DeprecationWarning: Please import `label` from the `scipy.ndimage` namespace; the `scipy.ndimage.measurements` namespace is deprecated and will be removed in SciPy 2.0.0.
  labels, num = measurements.label(np.logical_not(grid!=0))
/var/folders/8t/4cb18tcd3m52m3rjj24wt43c0000gn/T/ipykernel_48544/4101460493.py:21: DeprecationWarning: Please import `sum` from the `scipy.ndimage` namespace; the `scipy.ndimage.measurements` namespace is deprecated and will be removed in SciPy 2.0.0.
  areas = measurements.sum(np.logical_not(grid!=0), labels, index=range(1, num+1))