# WallaceMultiplier

from pyhgl.logic import * 

import random
import time 
import itertools


def compressN(column: list, sums: list, couts: list):
    """Compress one column of bits, three entries at a time.

    ``column`` is walked in consecutive groups of up to three entries.
    Each group contributes exactly one entry to ``sums``; groups of two or
    three entries additionally contribute one carry entry to ``couts``.
    Both output lists are extended in place and nothing is returned.

    Args:
        column: Entries to compress. A trailing group may hold one or two
            entries when the length is not a multiple of three.
        sums: List receiving the sum output of every group.
        couts: List receiving the carry output of every two- or
            three-entry group.
    """
    total = len(column)
    start = 0
    while start < total:
        group = column[start:start + 3]
        start += 3
        if len(group) == 3:
            # Full adder: sum is the XOR of all three inputs, carry is the
            # majority function written as a&b | c&(a^b).
            first, second, third = group
            sums.append(Xor(first, second, third))
            couts.append(first & second | third & (first ^ second))
        elif len(group) == 2:
            # Half adder: XOR for the sum, AND for the carry.
            sums.append(Xor(*group))
            couts.append(And(*group))
        else:
            # A lone entry is forwarded untouched and produces no carry.
            sums.append(group[0])


@module WallaceMultiplier(w1: int, w2: int):
    # Multiplier for two UInt operands, built by repeatedly compressing
    # columns of partial-product bits and finishing with a single addition.
    #   w1: bit width of input x
    #   w2: bit width of input y
    # The product port `out` is w1+w2 bits wide.
    io = Bundle(
        x = Input(UInt(0, w1)),
        y = Input(UInt(0, w2)),
        out = Output(UInt(0, w1+w2)),      
    )
    # products[k] collects every partial-product bit whose operand bit
    # indices add up to k.
    # NOTE(review): this is the bit weight only if split() yields bits
    # LSB-first -- confirm against the PyHGL docs.
    products = [[] for _ in range(w1+w2)] 
    # Pair every bit of x (index i) with every bit of y (index j); the AND of
    # the pair is filed under column i+j.
    for (i, x_i), (j, y_j) in itertools.product(enumerate(io.x.split()), enumerate(io.y.split())):
        products[i+j].append(x_i & y_j) 


    # Run compression rounds until no column holds three or more bits.
    while max(len(i) for i in products) >= 3: 
        # One extra column so carries out of the current top column have
        # somewhere to go.
        products_new = [[] for _ in range(len(products)+1)] 
        for i, column in enumerate(products):
            # Sums stay in column i, carries move to column i+1.
            compressN(column, products_new[i], products_new[i+1])
        products = products_new

    # Discard high-order columns that ended up with no bits.
    while len(products[-1]) == 0:
        products.pop()

    # Each remaining column now holds one or two bits (i[0] would raise on an
    # empty column). Row `a` takes the first bit of every column; row `b`
    # takes the second bit, or a constant zero where there is none.
    # Keep the names `a` and `b`: the test session in this file reads them as
    # dut.a and dut.b.
    a = Cat(*[i[0] for i in products])
    b = Cat(*[i[1] if len(i) == 2 else UInt(0) for i in products]) 
    # Final addition of the two rows drives the product output.
    io.out <== a + b


#--------------------------------- test ----------------------------------------

@task mult_test(self, dut, data): 
    out_mask = Logic( (1 << len(dut.io.out)) - 1, 0) 
    for x, y in data:
        setv(dut.io.x, x)
        setv(dut.io.y, y)
        yield self.clock_n()
        self.AssertEq(getv(dut.io.out), (x * y) & out_mask)


with Session() as sess:
    # W: operand width passed to the multiplier and to Logic.rand.
    # N: number of random operand pairs to run.
    W = 16
    N = 2000
    # Forwarded to Logic.rand as `ratio`.
    # NOTE(review): presumably the share of unknown bits in the generated
    # values -- confirm against the PyHGL docs.
    ratio = 0.0
    dut = WallaceMultiplier(W, W) 
    # Register the I/O bundle and the two final adder rows with the session.
    sess.track(dut.io, dut.a, dut.b)
    # N pairs, each made of two independent Logic.rand draws.
    test_data = [
        tuple(Logic.rand(W, W, ratio=ratio) for _ in range(2))
        for _ in range(N)
    ]

    # Run the test task, then write the waveform dump.
    sess.join(mult_test(dut, test_data))
    sess.dumpVCD('Multiplier.vcd')

    # Emit the generated Verilog and print the session.
    sess.dumpVerilog('Multiplier.sv', delay=True, top = True) 
    print(sess)