Skip to content

generalize multiplication_circuit to target multiple backend implementations #1187

@mhlr

Description

@mhlr

If there is interest in adding this, I can make a PR. I have the basic code below.

Application
Multiplication circuits share a common overall structure but could target different implementations of the basic constraints, e.g., a CQM, or a translation to k-SAT before encoding to a QUBO.

Proposed Solution

Factor out the large-scale structure of the multiplication circuit from the implementation of the basic building blocks, using the Builder pattern.
Something like:

import functools as ft
import itertools as it
import uuid

from dimod import BQM, CQM, Binary, quicksum
from dimod.generators import gates
from dwave.system import LeapHybridCQMSampler


def build_multiplier(builder, num_arg1_bits, num_arg2_bits=None):
    """Lay out a multiplication circuit, delegating every cell to ``builder``.

    The circuit is a grid of cells indexed by ``(i, j)``, where ``i`` runs
    over the bits of the first argument and ``j`` over the bits of the
    second.  Wires are named ``a{i}``, ``b{j}``, ``p{k}`` (product bits),
    ``sum{i},{j}`` and ``carry{i},{j}``; each name is wrapped in ``Binary``
    before being handed to the builder.

    Args:
        builder: object providing ``and_gate``, ``half_adder``,
            ``full_adder`` and ``model``.  A cell with 2, 3 or 4 input wires
            is dispatched to ``and_gate``, ``half_adder`` or ``full_adder``
            respectively, called with the input wires followed by the
            output wires.
        num_arg1_bits: number of bits of the first argument (>= 1).
        num_arg2_bits: number of bits of the second argument (>= 1).
            Defaults to ``num_arg1_bits`` when omitted.

    Returns:
        Whatever ``builder.model()`` returns after all cells were added.

    Raises:
        ValueError: if either size is smaller than 1.
    """
    if num_arg1_bits < 1:
        raise ValueError("num_arg1_bits must be a positive integer")

    # Only a *missing* second size falls back to the first one.  Using
    # ``or`` here would also replace an explicit 0, which must instead be
    # rejected by the range check below.
    if num_arg2_bits is None:
        num_arg2_bits = num_arg1_bits

    if num_arg2_bits < 1:
        raise ValueError("the arg2 must have a positive size")

    num_product_bits = num_arg1_bits + num_arg2_bits

    # throughout, we will use the following convention:
    #   i to refer to the bits of arg1
    #   j to refer to the bits of arg2

    # Dispatch table keyed by the number of input wires of a cell:
    # 2 -> and_gate, 3 -> half_adder, 4 -> full_adder.
    components = dict(
        enumerate([builder.and_gate, builder.half_adder, builder.full_adder], 2)
    )

    def SUM(i, j):
        # Name of the sum wire leaving cell (i, j).  Cells in column j == 0
        # and cells in the last row write directly to a product bit.
        return (
            f"p{i}"
            if j == 0
            else (f"p{i + j}" if i == num_arg1_bits - 1 else f"sum{i},{j}")
        )

    def CARRY(i, j):
        # Name of the carry wire leaving cell (i, j).  The carry of the
        # very last cell is the most significant product bit.
        return (
            f"p{num_product_bits - 1}"
            if i + j == num_product_bits - 2
            else f"carry{i},{j}"
        )

    def connections(i, j):
        # Collect the input and output wire names of cell (i, j).
        inputs = [f"a{i}", f"b{j}"]
        if i > 0:
            if j < num_arg2_bits - 1:
                # sum handed over from the previous row, one column up
                inputs.append(SUM(i - 1, j + 1))
            elif i > 1:
                # last column: take the previous row's final carry instead
                inputs.append(CARRY(i - 1, j))
            if j > 0:
                # carry rippling along the current row
                inputs.append(CARRY(i, j - 1))
        outputs = [SUM(i, j)]
        if i > 0:
            outputs.append(CARRY(i, j))
        # NOTE(review): for num_arg2_bits == 1 and i == 1 this yields two
        # inputs but two outputs, so ``and_gate`` would receive four wires.
        # Confirm whether single-bit second arguments need to be supported.
        return inputs, outputs

    cells = it.product(range(num_arg1_bits), range(num_arg2_bits))
    for inputs, outputs in it.starmap(connections, cells):
        components[len(inputs)](*map(Binary, inputs + outputs))

    return builder.model()

class MultiplicationCQM:
    """Builder that records every multiplier cell as one equality constraint.

    Each gate method adds a single labelled constraint to ``model_`` through
    its ``add_constraint`` method.
    """

    def __init__(self, model=None):
        # Use the supplied model when it is truthy, otherwise start a new CQM.
        self.model_ = model or CQM()

    def and_gate(self, a, b, out):
        """Add the constraint ``a * b == out``."""
        label = f"{a} and {b} == {out}"
        constraint = a * b - out == 0
        self.model_.add_constraint(constraint, label)

    def half_adder(self, a, b, sum_in, sum_out, carry_out):
        """Add the constraint ``a * b + sum_in == 2 * carry_out + sum_out``."""
        label = f"{a} * {b} + {sum_in} == 2*{carry_out} + {sum_out}"
        incoming = a * b + sum_in
        outgoing = 2*carry_out + sum_out
        self.model_.add_constraint(incoming - outgoing == 0, label)

    def full_adder(self, a, b, sum_in, carry_in, sum_out, carry_out):
        """Add ``a * b + sum_in + carry_in == 2 * carry_out + sum_out``."""
        label = f"{a} * {b} + {sum_in} + {carry_in} == 2*{carry_out} + {sum_out}"
        incoming = a * b + sum_in + carry_in
        outgoing = 2*carry_out + sum_out
        self.model_.add_constraint(incoming - outgoing == 0, label)

    def model(self):
        """Return the model holding all constraints added so far."""
        return self.model_

    
class MultiplicationGates:
    """Builder that realises each multiplier cell with dimod gate generators.

    Every gate produced by ``dimod.generators`` is collected in
    ``components_``; ``model()`` sums them into one model.  ``model_`` is
    stored by the constructor but is not read by any method of this class.
    """

    def __init__(self, model=None):
        self.model_ = model or BQM()
        self.components_ = []

    def and_gate(self, a, b, out):
        """Append an AND gate ``out = a AND b`` and return ``out``."""
        self.components_.append(gates.and_gate(a, b, out))
        return out

    def half_adder(self, a, b, *args):
        """Append ``a AND b`` followed by a half adder fed with its output.

        ``args`` are forwarded unchanged after the AND output, i.e. the
        remaining input followed by the sum and carry outputs.
        """
        product = self.and_gate(a, b, f'and{a},{b}')
        self.components_.append(gates.halfadder_gate(product, *args))

    def full_adder(self, a, b, *args):
        """Append ``a AND b`` followed by a full adder fed with its output.

        ``args`` are forwarded unchanged after the AND output, i.e. the two
        remaining inputs followed by the sum and carry outputs.
        """
        product = self.and_gate(a, b, f'and{a},{b}')
        self.components_.append(gates.fulladder_gate(product, *args))

    def model(self):
        """Return the sum of all gates collected so far."""
        return quicksum(self.components_)

                                
def and_model(n_aux=0):
    """Request a penalty model for the AND relation from ``maxgap_model``.

    The feasible configurations are the four rows ``(u, v, u * v)`` of the
    AND truth table, each mapped to 0.  ``n_aux`` is passed through
    unchanged as the second argument.
    """
    truth_table = {}
    for u, v in it.product((0, 1), repeat=2):
        truth_table[(u, v, u * v)] = 0
    return maxgap_model(truth_table, n_aux)

def fulladder_model(n_aux=0):
    """Build a penalty model for the cell ``u * v + y + z = sum + 2 * carry``.

    The specification uses a complete graph on ``6 + n_aux`` nodes whose
    first six nodes are the decision variables ``(u, v, y, z, sum, carry)``.
    Every feasible configuration is mapped to 0.
    """
    graph = nx.complete_graph(6 + n_aux)
    decision_variables = tuple(range(6))

    feasible = {}
    for u, v, y, z in it.product((0, 1), repeat=4):
        total = u * v + y + z
        # low bit is the sum output, ``total > 1`` is the carry output
        feasible[(u, v, y, z, total % 2, total > 1)] = 0

    spec = Specification(
        graph,
        decision_variables,
        feasible,
        dimod.BINARY,
        min_classical_gap=1e-12,
    )
    return mg.get_penalty_model(spec)

def halfadder_model(n_aux=0):
    """Build a penalty model for the cell ``u * v + y = sum + 2 * carry``.

    The specification uses a complete graph on ``5 + n_aux`` nodes whose
    first five nodes are the decision variables ``(u, v, y, sum, carry)``.
    Every feasible configuration is mapped to 0.
    """
    graph = nx.complete_graph(5 + n_aux)
    decision_variables = tuple(range(5))

    # Only three free inputs exist, so enumerate exactly 2**3 rows instead
    # of iterating a fourth, unused variable and producing every row twice.
    feasible = {}
    for u, v, y in it.product((0, 1), repeat=3):
        total = u * v + y
        # low bit is the sum output, ``total > 1`` is the carry output
        feasible[(u, v, y, total % 2, total > 1)] = 0

    spec = Specification(
        graph,
        decision_variables,
        feasible,
        dimod.BINARY,
        min_classical_gap=1e-12,
    )
    return mg.get_penalty_model(spec)

def relabel_gate(gate, labels):
    """Return a relabelled copy of ``gate``.

    The variable at position ``k`` (the key ``k`` of the mapping handed to
    ``gate.relabel_variables``) receives ``labels[k]``.  Positions without
    a caller-supplied label get a fresh random hex name, so that such
    variables of different gate copies never collide.

    Args:
        gate: object exposing ``num_variables`` and
            ``relabel_variables(mapping, inplace=...)``.
            NOTE(review): presumably the gate's variables are the integers
            ``0..num_variables - 1``; confirm against the penalty models.
        labels: iterable of new names; at most ``gate.num_variables`` items.

    Returns:
        The result of ``gate.relabel_variables(..., inplace=False)``.
    """
    # Accept any iterable (list, generator, ...), not only tuples.
    labels = tuple(labels)
    assert len(labels) <= gate.num_variables

    n_unnamed = gate.num_variables - len(labels)
    fresh = tuple(uuid.uuid4().hex for _ in range(n_unnamed))
    mapping = dict(enumerate(labels + fresh))
    return gate.relabel_variables(mapping, inplace=False)


class MultiplicationPenaltymodel:
    """Builder that realises each multiplier cell with a penalty model.

    One template model per cell type is created in the constructor; every
    gate call appends a relabelled copy of the matching template to
    ``components_``.  ``model()`` sums the collected copies.  ``model_`` is
    stored by the constructor but is not read by any method of this class.
    """

    def __init__(self, model=None):
        self.model_ = model or BQM()
        # Template models, taken from the ``.model`` attribute of each
        # penalty model; they are copied (never mutated) by relabel_gate.
        self.and_ = and_model().model
        self.half_ = halfadder_model().model
        self.full_ = fulladder_model().model
        self.components_ = []

    def and_gate(self, *labels):
        """Append an AND cell wired to the three given ``labels``."""
        assert len(labels) == 3
        self.components_.append(relabel_gate(self.and_, labels))

    def half_adder(self, *labels):
        """Append a half-adder cell wired to the five given ``labels``."""
        assert len(labels) == 5
        self.components_.append(relabel_gate(self.half_, labels))

    def full_adder(self, *labels):
        """Append a full-adder cell wired to the six given ``labels``."""
        assert len(labels) == 6
        self.components_.append(relabel_gate(self.full_, labels))

    def model(self):
        """Return the sum of all cells collected so far."""
        return quicksum(self.components_)

class MultiplySAT:
    """Placeholder for a SAT-based builder; nothing is implemented yet."""
    # TODO: translate to SAT clauses before encoding on CQM

def bits2int(bits):
    """Interpret ``bits`` as an unsigned integer, most significant bit first.

    Args:
        bits: iterable of 0/1 (or boolean) digits; it is consumed once.

    Returns:
        The integer value; 0 for an empty iterable.
    """
    bits = list(bits)
    res = ft.reduce(lambda acc, bit: 2 * acc + bit, bits, 0)
    #ic(bits, res)
    assert res == int(res)
    # n binary digits encode at most 2**n - 1, so the upper bound is strict.
    assert 0 <= res < 2**len(bits)
    return res


def factor_vars(factor, sample):
    """Return the bit-variable names of ``factor``, highest index first.

    A key of ``sample`` qualifies when it consists of the prefix ``factor``
    followed by a decimal index, e.g. ``a0``, ``a1``, ``a10`` for ``"a"``.
    Keys that merely share the prefix (such as ``and...`` helper variables)
    are skipped instead of failing the integer conversion.

    Args:
        factor: variable-name prefix.
        sample: mapping whose keys are variable names.

    Returns:
        List of matching names sorted by descending numeric index.
    """
    prefix_len = len(factor)
    res = [
        var for var in sample.keys()
        if var.startswith(factor) and var[prefix_len:].isdecimal()
    ]
    res.sort(key=lambda var: int(var[prefix_len:]), reverse=True)
    #ic(factor, sample, res)
    return res


def factor_val(factor, sample):
    """Decode the integer stored in the bit variables of ``factor``.

    The names come from ``factor_vars`` (highest index first); a bit counts
    as set when its sample value is greater than 0.
    """
    names = factor_vars(factor, sample)
    value = bits2int(sample[name] > 0 for name in names)
    #ic(factor, sample, value)
    return value

# --- Example run: problem definition ---------------------------------------
# NOTE(review): the next line looks like a leftover notebook cell magic.
#%%time
#P=35

# Alternative factor pairs that were tried; uncomment one pair to use it.
# https://primes.utm.edu/lists/2small/0bit.html
# A = 2**32 - 5
# B = 2**32 - 17

# A = 2**10 - 3
# B = 2**10 - 5

#A = 2**8 - 5
#B = 2**8 - 15

# Factor pair currently in use.
A = 33
B = 47

#A=13
#B=11

#A=29
#B=31

#A=7
#B=5

# Passed as ``time_limit`` to the sampler's ``sample_cqm`` call further below.
time_limit = 120
# The number to factor is built from the two known factors.
P = A * B
print(A, B, P)

def fixed_vars(var, val):
    """Map bit-variable names of ``var`` to the binary digits of ``val``.

    The binary representation of ``val`` is walked from its most
    significant digit, so for ``val = 5`` and ``var = "p"`` the result is
    ``{"p2": 1, "p1": 0, "p0": 1}``.  No zero padding is applied.
    """
    digits = format(val, "b")
    top = len(digits) - 1
    return {f"{var}{top - pos}": int(digit) for pos, digit in enumerate(digits)}

# --- Example run: build the circuit, fix the product, sample ----------------
# The factor encodings are only used to size the circuit (and, optionally,
# to pin the factors for debugging via the commented-out lines below).
a_vars = fixed_vars('a', A)
b_vars = fixed_vars('b', B)

nbits1 = len(a_vars)
nbits2 = len(b_vars)

# Convert P from decimal to binary, one entry per product wire.  The circuit
# has nbits1 + nbits2 product bits, which can be more than P's own bit
# length; the leading wires must be pinned to 0, otherwise the sampler is
# free to set them and return factors of a different number.
p_vars = {f"p{k}": (P >> k) & 1 for k in range(nbits1 + nbits2)}

builder = MultiplicationCQM()
cqm = build_multiplier(builder, nbits1, nbits2)

# Fix product variables
for var, value in it.chain(p_vars.items(),
                           # a_vars.items(),
                           # b_vars.items()
                           ):
    cqm.fix_variable(var, value)

sampler = LeapHybridCQMSampler()
print(sampler.solver.name)

sampleset = sampler.sample_cqm(cqm, time_limit=time_limit)

#assert (sampleset.to_pandas_dataframe(True).is_feasible.sum() > 0) == (P == A * B)

df = sampleset.aggregate().to_pandas_dataframe(True)
print(df.is_feasible.sum(), "feasible samples")

# Decode both factors from every feasible sample and report them.
for sample in df['sample'][df.is_feasible]:
    aa, bb = (factor_val(v, sample) for v in 'ab')
    print(F"Success: {P} => {aa} * {bb} == {aa*bb}")

Additional Context

This would simplify dwave-examples/factoring-notebook#22

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions