# AES

# reference: https://github.com/secworks/aes 

from pyhgl.logic import *
from pyhgl.tester import *  
import pyhgl.logic.utils as utils
import time

# multiplication under Rijndael’s finite field
# --------------------------------------------
def gm2(op):
    """Multiply ``op`` by 2 under Rijndael's finite field.

    The result is ``op << 1`` XORed with the constant ``'8:h1b'`` masked by
    ``op[-1]**8`` (presumably the top bit of ``op`` replicated to 8 bits in
    PyHGL -- confirm against the PyHGL operator reference).
    """
    shifted = op << 1
    reduction = op[-1]**8 & '8:h1b'
    return shifted ^ reduction

def gm3(op):
    """Multiply ``op`` by 3: ``gm2(op)`` XORed with ``op`` itself."""
    doubled = gm2(op)
    return doubled ^ op

def gm4(op):
    """Multiply ``op`` by 4 by applying ``gm2`` twice."""
    doubled = gm2(op)
    return gm2(doubled)

def gm8(op):
    """Multiply ``op`` by 8: ``gm2`` applied to ``gm4(op)``."""
    quadrupled = gm4(op)
    return gm2(quadrupled)

def gm09(op):
    """Multiply ``op`` by 9: ``gm8(op)`` XORed with ``op``."""
    times_eight = gm8(op)
    return times_eight ^ op

def gm11(op):
    """Multiply ``op`` by 11: ``gm8(op) ^ gm2(op) ^ op``."""
    # Terms are combined left to right, exactly as in the one-line form.
    partial = gm8(op) ^ gm2(op)
    return partial ^ op

def gm13(op):
    """Multiply ``op`` by 13: ``gm8(op) ^ gm4(op) ^ op``."""
    # Terms are combined left to right, exactly as in the one-line form.
    partial = gm8(op) ^ gm4(op)
    return partial ^ op

def gm14(op):
    """Multiply ``op`` by 14: ``gm8(op) ^ gm4(op) ^ gm2(op)``."""
    # Terms are combined left to right, exactly as in the one-line form.
    partial = gm8(op) ^ gm4(op)
    return partial ^ gm2(op)

def mixw(w):
    """
    Multiply one 4-byte word by the matrix

        2 3 1 1     a0        b0
        1 2 3 1     a1   ->   b1
        1 1 2 3  *  a2        b2
        3 1 1 2     a3        b3

    input:  array of 4 bytes: [a3, a2, a1, a0]
    return: array of 4 bytes: [b3, b2, b1, b0]
    """
    def gm1(op):
        # Multiplication by 1 leaves the byte untouched.
        return op

    a3, a2, a1, a0 = w
    # Multipliers of the first matrix row; every following row is the
    # previous one rotated right by one position.
    first_row = (gm2, gm3, gm1, gm1)
    mixed = []
    for shift in range(4):
        m0, m1, m2, m3 = first_row[4 - shift:] + first_row[:4 - shift]
        mixed.append(m0(a0) ^ m1(a1) ^ m2(a2) ^ m3(a3))
    # mixed holds b0..b3; the result is returned as [b3, b2, b1, b0].
    return Array(mixed[::-1])

def inv_mixw(w):
    """
    Multiply one 4-byte word by the matrix

        14 11 13  9     a0        b0
         9 14 11 13     a1   ->   b1
        13  9 14 11  *  a2        b2
        11 13  9 14     a3        b3

    input:  array of 4 bytes: [a3, a2, a1, a0]
    return: array of 4 bytes: [b3, b2, b1, b0]
    """
    a3, a2, a1, a0 = w
    # Multipliers of the first matrix row; every following row is the
    # previous one rotated right by one position.
    first_row = (gm14, gm11, gm13, gm09)
    mixed = []
    for shift in range(4):
        m0, m1, m2, m3 = first_row[4 - shift:] + first_row[:4 - shift]
        mixed.append(m0(a0) ^ m1(a1) ^ m2(a2) ^ m3(a3))
    # mixed holds b0..b3; the result is returned as [b3, b2, b1, b0].
    return Array(mixed[::-1])

# a:  Array of 16 bytes
def mixcolumns(a):
    """Apply ``mixw`` to each of the 4 groups of ``a._reshape(4,4)``.

    The per-group results are collected in an ``Array`` and returned
    flattened through ``._flat``.
    """
    words = a._reshape(4,4)
    mixed = Array(mixw(word) for word in words)
    return mixed._flat

def inv_mixcolumns(a):
    """Apply ``inv_mixw`` to each of the 4 groups of ``a._reshape(4,4)``.

    The per-group results are collected in an ``Array`` and returned
    flattened through ``._flat``.
    """
    words = a._reshape(4,4)
    unmixed = Array(inv_mixw(word) for word in words)
    return unmixed._flat

def shiftrows(a):
    """
    Permute the 16 bytes of ``a``.

    input:
        a00 a01 a02 a03
        a10 a11 a12 a13
        a20 a21 a22 a23
        a30 a31 a32 a33

    ``a`` is viewed as ``a._reshape(4,4)``; output element (r, c), emitted
    in row-major order, is taken from input position ((r + c + 1) % 4, c).
    """
    grid = a._reshape(4,4)
    return Array([grid[(r + c + 1) % 4, c]
                  for r in range(4)
                  for c in range(4)])

def inv_shiftrows(a):
    """
    Undo the byte permutation performed by ``shiftrows``.

    ``a`` is viewed as ``a._reshape(4,4)``; output element (r, c), emitted
    in row-major order, is taken from input position ((r - c - 1) % 4, c).
    """
    grid = a._reshape(4,4)
    return Array([grid[(r - c - 1) % 4, c]
                  for r in range(4)
                  for c in range(4)])


def addroundkey(data, rkey):
    """Combine ``data`` with the round key ``rkey`` using XOR."""
    keyed = data ^ rkey
    return keyed

# AES encipher datapath together with its round-control state machine.
#
# Besides its own io bundle the module reads conf.p.block, conf.p.round_key,
# conf.p.sbox, conf.p.next and conf.p.num_rounds; aes_core assigns all of
# these before it instantiates this module.
@module aes_encipher_block:
    io = Bundle(
        round       = UInt('4:b0')          @ Output,   # nth round
        new_block   = UInt[8].zeros(16)     @ Output,   # 16 bytes of block state
        ready       = UInt('0')             @ Output,   # set again when the final round is issued
        sboxw       = UInt[8].zeros(4)      @ Output,   # 4 bytes offered to the shared sbox
    )  

    sword_ctr_reg = Reg(UInt('2:b0'))       # word counter, read sbox 4 cycles
    round_ctr_reg = Reg(io.round)           # round counter 

    # NOTE(review): Reg(io.x) presumably turns the output signal itself into
    # a register -- confirm against the PyHGL documentation.
    block_reg = Reg(io.new_block) 
    ready_reg = Reg(io.ready) 
    
    # Datapath operation selected by the control FSM below. 'NO_UPDATE' is the
    # first enum member and has no branch, so block_reg is not assigned for it.
    switch update_type:=Enum['NO_UPDATE',...]():
        once 'INIT_UPDATE':
            # initial step: input block XOR round key
            block_reg <== addroundkey(conf.p.block, conf.p.round_key) 
        once 'SBOX_UPDATE':
            # One 4-byte slice per value of sword_ctr_reg: the slice is driven
            # onto io.sboxw and overwritten with conf.p.sbox.new_sboxw.
            switch sword_ctr_reg:
                once 0: io.sboxw,block_reg[12:16] <== block_reg[12:16],conf.p.sbox.new_sboxw
                once 1: io.sboxw,block_reg[8:12]  <== block_reg[8:12],conf.p.sbox.new_sboxw 
                once 2: io.sboxw,block_reg[4:8]   <== block_reg[4:8],conf.p.sbox.new_sboxw 
                once 3: io.sboxw,block_reg[0:4]   <== block_reg[0:4],conf.p.sbox.new_sboxw 
        once 'MAIN_UPDATE':
            # regular round: shiftrows -> mixcolumns -> addroundkey
            block_reg <== addroundkey(mixcolumns(shiftrows(block_reg)), conf.p.round_key) 
        once 'FINAL_UPDATE':
            # last round: same as MAIN_UPDATE but without mixcolumns
            block_reg <== addroundkey(shiftrows(block_reg), conf.p.round_key) 

    # Control FSM: IDLE -> INIT -> (SBOX -> MAIN)* -> IDLE
    switch state:=Reg(Enum()):
        once 'CTRL_IDLE':  
            # wait for conf.p.next, then clear the round counter and ready flag
            when conf.p.next:
                round_ctr_reg, ready_reg <== 0 
                state <==  'CTRL_INIT'
        once 'CTRL_INIT': 
            round_ctr_reg <== round_ctr_reg + 1  
            sword_ctr_reg <== 0
            update_type <== 'INIT_UPDATE'  
            state <== 'CTRL_SBOX'
        once 'CTRL_SBOX': 
            # substitute one word per visit; leave once sword_ctr_reg reads 3
            sword_ctr_reg <== sword_ctr_reg + 1 
            update_type <==  'SBOX_UPDATE' 
            when sword_ctr_reg == 3:
                state <== 'CTRL_MAIN'
        once 'CTRL_MAIN': 
            sword_ctr_reg <== 0  
            round_ctr_reg <== round_ctr_reg + 1 
            # more rounds pending -> regular round, otherwise final round
            when round_ctr_reg < conf.p.num_rounds:
                update_type <== 'MAIN_UPDATE' 
                state <== 'CTRL_SBOX' 
            otherwise:
                update_type <== 'FINAL_UPDATE'
                ready_reg <== 1 
                state <== 'CTRL_IDLE'


# AES decipher datapath together with its round-control state machine.
#
# Reads conf.p.block, conf.p.round_key, conf.p.next and conf.p.num_rounds,
# which aes_core assigns before instantiating this module. Unlike the
# encipher block it instantiates its own aes_inv_sbox rather than using
# conf.p.sbox.
@module aes_decipher_block:
    io = Bundle(
        round       = UInt('4:b0')          @ Output,   # nth round
        new_block   = UInt[8].zeros(16)     @ Output,   # 16 bytes of block state
        ready       = UInt('0')             @ Output,   # set again when the final step is issued
    )

    sword_ctr_reg = Reg(UInt('2:b0'))   # word counter, read sbox 4 cycles
    round_ctr_reg = Reg(io.round)       # round counter, decrease because use later keys 

    # NOTE(review): Reg(io.x) presumably turns the output signal itself into
    # a register -- confirm against the PyHGL documentation.
    block_reg = Reg(io.new_block) 
    ready_reg = Reg(io.ready)
    
    inv_sbox = aes_inv_sbox() 

    # Datapath operation selected by the control FSM below. 'NO_UPDATE' is the
    # first enum member and has no branch, so block_reg is not assigned for it.
    switch update_type:=Enum['NO_UPDATE',...]():
        once 'INIT_UPDATE':
            # initial step: (input block XOR round key), then inv_shiftrows
            block_reg <== inv_shiftrows(addroundkey(conf.p.block, conf.p.round_key)) 
        once 'SBOX_UPDATE':
            # One 4-byte slice per value of sword_ctr_reg: the slice is driven
            # into inv_sbox.sboxw and overwritten with inv_sbox.new_sboxw.
            switch sword_ctr_reg:
                once 0: inv_sbox.sboxw,block_reg[12:16] <== block_reg[12:16],inv_sbox.new_sboxw
                once 1: inv_sbox.sboxw,block_reg[8:12]  <== block_reg[8:12],inv_sbox.new_sboxw
                once 2: inv_sbox.sboxw,block_reg[4:8]   <== block_reg[4:8],inv_sbox.new_sboxw
                once 3: inv_sbox.sboxw,block_reg[0:4]   <== block_reg[0:4],inv_sbox.new_sboxw 
        once 'MAIN_UPDATE':
            # regular round: addroundkey -> inv_mixcolumns -> inv_shiftrows
            block_reg <== inv_shiftrows(inv_mixcolumns(addroundkey(block_reg, conf.p.round_key))) 
        once 'FINAL_UPDATE':
            # last step: only the round key is applied
            block_reg <== addroundkey(block_reg, conf.p.round_key)
        
    # Control FSM: IDLE -> INIT -> (SBOX -> MAIN)* -> IDLE
    switch state:=Reg(Enum()):
        once 'CTRL_IDLE':
            # wait for conf.p.next; the round counter starts at num_rounds
            when conf.p.next:
                round_ctr_reg <== conf.p.num_rounds 
                ready_reg <== 0 
                state <== 'CTRL_INIT'
        once 'CTRL_INIT':
            sword_ctr_reg <== 0 
            update_type <== 'INIT_UPDATE' 
            state <== 'CTRL_SBOX'
        once 'CTRL_SBOX': 
            # substitute one word per visit; on the visit where sword_ctr_reg
            # reads 3 the round counter is decremented and the state advances
            sword_ctr_reg <== sword_ctr_reg + 1 
            update_type <== 'SBOX_UPDATE'  
            when sword_ctr_reg == 3:
                round_ctr_reg <== round_ctr_reg - 1 
                state <== 'CTRL_MAIN'
        once 'CTRL_MAIN': 
            sword_ctr_reg <== 0 
            # rounds remaining -> regular round, otherwise final step
            when round_ctr_reg > 0:
                update_type <== 'MAIN_UPDATE'  
                state <== 'CTRL_SBOX' 
            otherwise:
                update_type <== 'FINAL_UPDATE'
                ready_reg <== 1 
                state <== 'CTRL_IDLE'

# Top-level AES core: wires the key memory, the shared sbox and the
# encipher/decipher blocks together and sequences key initialisation
# (io.init) and block processing (io.next).
#
# Values needed by the submodules are published through conf.p before the
# submodules are instantiated.
@module aes_core:
    io = Bundle(
        encdec          = UInt('0')         @Input,     # first Mux operand below is the encipher path
        init            = UInt('0')         @Input,     # start key expansion
        next            = UInt('0')         @Input,     # start processing of io.block
        ready           = UInt('0')         @Output,

        key             = UInt('256:b0')    @Input, 
        keylen          = UInt('0')         @Input,     # selects between 14 and 10 rounds (see num_rounds)
        block           = UInt('128:b0')    @Input, 
        result          = UInt('128:b0')    @Output,
        result_valid    = UInt('0')         @Output,
    )

    conf.p.init = io.init                       # pass to aes_key_mem  
    conf.p.next = io.next                       # pass to enc/dec block 
    conf.p.num_rounds = Mux(io.keylen, 14, 10)  # pass to enc/dec block 
    conf.p.keylen = io.keylen                   # pass to aes_key_mem 
    conf.p.sbox = aes_sbox()                    # sbox is accessed by both enc_block and aes_key_mem
    conf.p.key = io.key.split(8)                # pass to aes_key_mem

    result_valid_reg = Reg(io.result_valid)
    ready_reg = Reg(io.ready)
    
    keymem = aes_key_mem()                      # key expansion

    conf.p.round_key = keymem.io.round_key      # pass to enc/dec block
    conf.p.block = io.block.split(8)            # pass to enc/dec block

    enc_block = aes_encipher_block()
    dec_block = aes_decipher_block() 

    # driven to 1 by the FSM while key initialisation owns the shared sbox
    init_state = UInt(0)

    # sbox_mux
    conf.p.sbox.sboxw <== Mux(init_state, keymem.io.sboxw, enc_block.io.sboxw)

    # encdec_mux
    keymem.io.round <== Mux(io.encdec, enc_block.io.round, dec_block.io.round)
    io.result       <== Mux(io.encdec, Cat(*enc_block.io.new_block), Cat(*dec_block.io.new_block))
    muxed_ready     =   Mux(io.encdec, enc_block.io.ready, dec_block.io.ready)


    # Control FSM: IDLE -> INIT -> IDLE (key setup) or IDLE -> NEXT -> IDLE
    # (one block). io.init takes priority over io.next.
    switch state:=Reg(Enum()):
        once 'CTRL_IDLE': 
            when io.init:
                init_state <== 1 
                ready_reg,result_valid_reg  <== 0 
                state <== 'CTRL_INIT' 
            elsewhen io.next:
                init_state, ready_reg, result_valid_reg  <== 0 
                state <== 'CTRL_NEXT'
        once 'CTRL_INIT': 
            # stay here until the key memory reports ready
            init_state <== 1 
            when  keymem.io.ready:
                ready_reg <== 1 
                state <== 'CTRL_IDLE'
        once 'CTRL_NEXT': 
            # stay here until the selected enc/dec block reports ready
            init_state <== 0 
            when muxed_ready:
                ready_reg, result_valid_reg <== 1 
                state <== 'CTRL_IDLE'

# Round-key generator and storage.
#
# Reads conf.p.key, conf.p.keylen, conf.p.init, conf.p.num_rounds and
# conf.p.sbox, all of which aes_core assigns before instantiating this
# module. Generated round keys are stored in key_mem and the one selected
# by io.round is presented on io.round_key.
@module aes_key_mem:
    io = Bundle(
        round       = UInt('4:b0')          @ Input,   # input nth round
        round_key   = UInt[8].zeros(16)     @ Output,  # output nth key
        ready       = UInt('0')             @ Output,  # init ready
        sboxw       = UInt[8].zeros(4)      @ Output,  # result
    )
    
    # 16 x Vec[15,8]; 16 columns of bytes
    key_mem = Reg( (15*UInt[8])(0) for _ in range(16) )   
    # output
    io.round_key <== (key_mem_column[io.round] for key_mem_column in key_mem)

    # round key to be stored next; created from conf.p.key[16:] and
    # overridden in the branches below
    key_mem_new   = WireNext(conf.p.key[16:])    
    round_ctr_reg = Reg(UInt('4:b0'))           # round counter, nth round
    rcon_reg = Reg(UInt('8:b0'))                # round constant 
    ready_reg = Reg(io.ready)

    # driven to 1 by the FSM while in the 'GENERATE' state
    round_key_update = UInt(0) 

    when round_key_update:                      # store generated keys
        # each of the 16 byte columns stores its new byte at index round_ctr_reg
        # (odd_mem is presumably a typo for old_mem)
        for odd_mem, new_mem in zip(key_mem, key_mem_new):
            odd_mem[round_ctr_reg] <== new_mem

    # update rcon
    # NOTE(review): with gm2 as defined in this file and 8-bit truncation of
    # the shift, gm2('8:h8d') would give 0x01 -- confirm.
    rcon_next = UInt(0) 
    when !round_key_update:         # init rcon
        rcon_reg <== '8:h8d'
    when rcon_next:                 # next rcon
        rcon_reg <== gm2(rcon_reg)
    
    # k0..k3: the four 4-byte words of the key being generated
    k0, k1, k2, k3 = UInt[8].zeros(4,4) 
    # w0..w7: previously generated words; both groups are created from
    # conf.p.key[16:]
    w3, w2, w1, w0 = RegNext(conf.p.key[16:])._reshape(4,4)
    w7, w6, w5, w4 = RegNext(conf.p.key[16:])._reshape(4,4)  
    io.sboxw <== w7  
    # tw: sbox output for w7; trw: tw with its bytes reordered and rcon_reg
    # XORed into the last listed byte
    tw = conf.p.sbox.new_sboxw
    trw = Array([ tw[3], tw[0], tw[1], tw[2] ^ rcon_reg]) 

    when round_key_update: 
        rcon_next <== 1
        when conf.p.keylen == 0:                # AES128
            # for round 0 no branch is taken, so key_mem_new and w4..w7 are
            # not overridden here
            when round_ctr_reg != 0:
                k0, k1, k2, k3 <== w4^trw, w5^w4^trw, w6^w5^w4^trw, w7^w6^w5^w4^trw
                key_mem_new <== *k3, *k2, *k1, *k0
                w7, w6, w5, w4 <== k3, k2, k1, k0 
        otherwise:                              # AES256
            when round_ctr_reg == 0:
                rcon_next <== 0
            elsewhen round_ctr_reg == 1:
                # second stored key comes from conf.p.key[:16]
                key_mem_new <== conf.p.key[:16] 
                w7, w6, w5, w4 <== conf.p.key[:16]._reshape(4,4) 
            otherwise:
                # even rounds use trw and hold rcon; odd rounds use tw
                when round_ctr_reg[0] == 0:
                    k0, k1, k2, k3 <== w0^trw, w1^w0^trw, w2^w1^w0^trw, w3^w2^w1^w0^trw
                    rcon_next <== 0
                otherwise:
                    k0, k1, k2, k3 <== w0^tw, w1^w0^tw, w2^w1^w0^tw, w3^w2^w1^w0^tw
                key_mem_new <== *k3, *k2, *k1, *k0 
                # w7..w4 take the new key, w3..w0 take the old w7..w4
                w7, w6, w5, w4, w3, w2, w1, w0 <== k3, k2, k1, k0, w7, w6, w5, w4 

    # Control FSM: IDLE -> INIT -> GENERATE (repeated) -> DONE -> IDLE
    switch state:=Reg(Enum()):
        once 'IDLE': 
            when conf.p.init:
                ready_reg <== 0 
                state <== 'INIT'
        once 'INIT': 
            round_ctr_reg <== 0 
            state <== 'GENERATE' 
        once 'GENERATE': 
            # one round key per visit; leave once round_ctr_reg reads num_rounds
            round_ctr_reg <== round_ctr_reg + 1 
            round_key_update <== 1 
            when round_ctr_reg == conf.p.num_rounds:
                state <== 'DONE' 
        once 'DONE': 
            ready_reg <== 1 
            state <== 'IDLE'


# Forward substitution box: maps each of the 4 bytes in sboxw through a
# 256-entry lookup table. Other modules drive `sboxw` and read `new_sboxw`
# by name.
@module aes_sbox: 
    """ input: 4 bytes
    """
    # 4-byte input word, one table lookup per byte
    sboxw = UInt(['8:b0'] * 4)
    # 256-entry table of 8-bit values, indexed by the input byte
    sbox = Vector(256, UInt[8])([
        0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
        0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
        0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
        0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
        0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
        0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
        0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
        0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
    ]) 
    # substituted word: table entry for every byte of sboxw
    new_sboxw = Array(sbox[i] for i in sboxw)


# Inverse substitution box: maps each of the 4 bytes in sboxw through a
# 256-entry lookup table. aes_decipher_block drives `sboxw` and reads
# `new_sboxw` by name.
@module aes_inv_sbox: 
    # 4-byte input word, one table lookup per byte
    sboxw = UInt(['8:b0'] * 4)
    # 256-entry table of 8-bit values, indexed by the input byte
    sbox = Vector(256, UInt[8])([
        0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb, 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
        0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e, 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
        0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92, 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
        0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06, 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
        0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73, 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
        0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b, 0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
        0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f, 0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
        0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61, 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
    ]) 
    # substituted word: table entry for every byte of sboxw
    new_sboxw = Array(sbox[i] for i in sboxw)


# Simulation task: load a key, process one block and check the result.
#   dut:      device under test; its io.encdec/keylen/key/block/init/next
#             inputs are driven and io.ready/io.result are observed
#   mode:     one of '128enc', '128dec', '256enc', '256dec'
#   key:      value written to dut.io.key
#   text:     value written to dut.io.block
#   expected: value dut.io.result is compared against
@task test_enc_dec(self, dut, mode, key, text, expected):
    assert mode in ['128enc','128dec','256enc','256dec']
    # split the mode string into direction (last 3 chars) and key size (first 3)
    mode_encode = (mode[-3:] == 'enc')
    mode_256 = (mode[:3] == '256') 
    setv(dut.io.encdec, mode_encode)
    setv(dut.io.keylen, mode_256)
    setv(dut.io.key, key)
    setv(dut.io.block, text)
    # pulse init to start key setup
    setv(dut.io.init, 1)
    yield self.clock_n(2)
    setv(dut.io.init, 0) 
    # NOTE(review): yielding a signal presumably suspends the task until
    # that signal is asserted -- confirm against the PyHGL tester docs.
    yield dut.io.ready 
    # pulse next to start processing the block
    setv(dut.io.next, 1)
    yield self.clock_n(2)
    setv(dut.io.next, 0)
    yield dut.io.ready
    result = getv(dut.io.result) 
    self.AssertEq(result, expected)
    print(f"mode:{mode}, key:{hex(key)}, text:{hex(text)}, expected:{hex(expected)}, result:{hex(result.v)}")


# Testbench: builds an aes_core, runs 20 encrypt/decrypt checks through
# test_enc_dec, and dumps Verilog, VCD and graph files.
@tester test_AES(self):
    sess = Session()
    sess.enter() 
    # NOTE(review): keeps the first element of conf.reset and sets the second
    # to 0; the meaning of the second element is not visible here -- confirm.
    conf.reset = (conf.reset[0], 0)

    t = time.time()
    dut = aes_core()  
    print('pyhgl build: ', time.time()-t)

    # 128-bit keys occupy the upper half of the 256-bit key input; the lower
    # half is zero.
    aes128_key1 = 0x2b7e151628aed2a6abf7158809cf4f3c00000000000000000000000000000000
    aes128_key2 = 0x000102030405060708090a0b0c0d0e0f00000000000000000000000000000000
    aes256_key1 = 0x603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4
    aes256_key2 = 0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f 

    # NOTE(review): these vectors look like the published NIST AES examples
    # (SP 800-38A ECB and FIPS-197 Appendix C) -- confirm the source.
    plaintext0 = 0x6bc1bee22e409f96e93d7e117393172a
    plaintext1 = 0xae2d8a571e03ac9c9eb76fac45af8e51
    plaintext2 = 0x30c81c46a35ce411e5fbc1191a0a52ef
    plaintext3 = 0xf69f2445df4f9b17ad2b417be66c3710
    plaintext4 = 0x00112233445566778899aabbccddeeff

    # expected ciphertexts for plaintext0..3 under key1 and plaintext4 under key2
    _128_enc_expected0 = 0x3ad77bb40d7a3660a89ecaf32466ef97
    _128_enc_expected1 = 0xf5d3d58503b9699de785895a96fdbaaf
    _128_enc_expected2 = 0x43b1cd7f598ece23881b00e3ed030688
    _128_enc_expected3 = 0x7b0c785e27e8ad3f8223207104725dd4
    _128_enc_expected4 = 0x69c4e0d86a7b0430d8cdb78070b4c55a

    _256_enc_expected0 = 0xf3eed1bdb5d2a03c064b5a7e3db181f8
    _256_enc_expected1 = 0x591ccb10d410ed26dc5ba74a31362870
    _256_enc_expected2 = 0xb6ed21b99ca6f4f9f153e7b1beafed1d
    _256_enc_expected3 = 0x23304b7a39f9f3ff067d8d8f9e24ecc7
    _256_enc_expected4 = 0x8ea2b7ca516745bfeafc49904b496089

    # Reset once, then run every vector in both directions: the decrypt
    # cases feed the expected ciphertext back in and expect the plaintext.
    @task aes_tests(self):
        yield self.reset()
        yield test_enc_dec(dut, '128enc', aes128_key1, plaintext0, _128_enc_expected0)
        yield test_enc_dec(dut, '128enc', aes128_key1, plaintext1, _128_enc_expected1)
        yield test_enc_dec(dut, '128enc', aes128_key1, plaintext2, _128_enc_expected2)
        yield test_enc_dec(dut, '128enc', aes128_key1, plaintext3, _128_enc_expected3)
        yield test_enc_dec(dut, '128enc', aes128_key2, plaintext4, _128_enc_expected4)
        yield test_enc_dec(dut, '128dec', aes128_key1, _128_enc_expected0, plaintext0)
        yield test_enc_dec(dut, '128dec', aes128_key1, _128_enc_expected1, plaintext1)
        yield test_enc_dec(dut, '128dec', aes128_key1, _128_enc_expected2, plaintext2)
        yield test_enc_dec(dut, '128dec', aes128_key1, _128_enc_expected3, plaintext3)
        yield test_enc_dec(dut, '128dec', aes128_key2, _128_enc_expected4, plaintext4)

        yield test_enc_dec(dut, '256enc', aes256_key1, plaintext0, _256_enc_expected0)
        yield test_enc_dec(dut, '256enc', aes256_key1, plaintext1, _256_enc_expected1)
        yield test_enc_dec(dut, '256enc', aes256_key1, plaintext2, _256_enc_expected2)
        yield test_enc_dec(dut, '256enc', aes256_key1, plaintext3, _256_enc_expected3)
        yield test_enc_dec(dut, '256enc', aes256_key2, plaintext4, _256_enc_expected4)
        yield test_enc_dec(dut, '256dec', aes256_key1, _256_enc_expected0, plaintext0)
        yield test_enc_dec(dut, '256dec', aes256_key1, _256_enc_expected1, plaintext1)
        yield test_enc_dec(dut, '256dec', aes256_key1, _256_enc_expected2, plaintext2)
        yield test_enc_dec(dut, '256dec', aes256_key1, _256_enc_expected3, plaintext3)
        yield test_enc_dec(dut, '256dec', aes256_key2, _256_enc_expected4, plaintext4)

    # emit Verilog before the simulation is run
    sess.dumpVerilog('AES.sv', delay=True, top = False) 

    sess.track(dut, dut.dec_block, dut.keymem, dut.enc_block) 

    # run the task to completion and report the wall-clock simulation time
    t = time.time()
    sess.join(aes_tests())
    print('pyhgl sim:', time.time()-t)

    print(sess)
    sess.dumpVCD('AES.vcd') 
    sess.dumpGraph('AES.gv')
    sess.exit()