Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[flake8]
max-line-length = 88
extend-ignore = E203
51 changes: 51 additions & 0 deletions mcl/bindings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import ctypes
import inspect

from . import hook, utils

MCL_NAME = "__mcl_name__"


def is_mcl_method(obj):
return inspect.ismethod(obj) and hasattr(obj, MCL_NAME)


def method_binding(method_name=None):
def decorator(method_obj):
setattr(method_obj, MCL_NAME, method_name or method_obj.__name__)
return method_obj

return decorator


def _build_mcl_ctypes_binding(method_obj, mcl_method_name, klass_instance):
wrapper = utils.wrap_function(
hook.mclbls12_384,
mcl_method_name,
[
ctypes.POINTER(klass_instance),
ctypes.c_char_p,
ctypes.c_size_t,
ctypes.c_int64,
],
)
return wrapper


def class_binding():
def decorator(klass_instance):
method_name_obj_pairs = inspect.getmembers(
klass_instance, predicate=is_mcl_method
)

for method_name, method_obj in method_name_obj_pairs:
mcl_method_name = getattr(method_obj, MCL_NAME)
setattr(
klass_instance,
method_name,
_build_mcl_ctypes_binding(method_obj, mcl_method_name, klass_instance),
)

return klass_instance

return decorator
26 changes: 19 additions & 7 deletions mcl/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
import dataclasses
import inspect

from . import hook
from . import utils
from . import hook, utils

BUFFER_SIZE = 2048


def tryGetBuilderMethodFromGlobals(method_name: str) -> callable:
def tryGetBuilderMethodFromGlobals(method_name: str):
return globals().get("build" + method_name[0].upper() + method_name[1:])


Expand Down Expand Up @@ -51,7 +50,12 @@ def buildSetStr(cls):
wrapper = utils.wrap_function(
hook.mclbls12_384,
f"mclBn{cls.__name__}_setStr",
[ctypes.POINTER(cls), ctypes.c_char_p, ctypes.c_size_t, ctypes.c_int64],
[
ctypes.POINTER(cls),
ctypes.c_char_p,
ctypes.c_size_t,
ctypes.c_int64,
],
)

def setStr(self, value, mode=10):
Expand All @@ -75,7 +79,9 @@ def setInt(self, value):

def buildSetByCSPRNG(cls):
wrapper = utils.wrap_function(
hook.mclbls12_384, f"mclBn{cls.__name__}_setByCSPRNG", [ctypes.POINTER(cls)],
hook.mclbls12_384,
f"mclBn{cls.__name__}_setByCSPRNG",
[ctypes.POINTER(cls)],
)

def setByCSPRNG(self):
Expand Down Expand Up @@ -120,7 +126,9 @@ def isEqual(self, other):

def buildIsOne(cls):
wrapper = utils.wrap_function(
hook.mclbls12_384, f"mclBn{cls.__name__}_isOne", [ctypes.POINTER(cls)],
hook.mclbls12_384,
f"mclBn{cls.__name__}_isOne",
[ctypes.POINTER(cls)],
)

def isOne(self, other):
Expand Down Expand Up @@ -248,7 +256,11 @@ def buildPairing(cls, left_group, right_group):
wrapper = utils.wrap_function(
hook.mclbls12_384,
f"mclBn_pairing",
(ctypes.POINTER(cls), ctypes.POINTER(left_group), ctypes.POINTER(right_group)),
(
ctypes.POINTER(cls),
ctypes.POINTER(left_group),
ctypes.POINTER(right_group),
),
)

@staticmethod
Expand Down
8 changes: 6 additions & 2 deletions mcl/structures/Fp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import ctypes

from .. import builder
from .. import consts
from .. import builder, consts


@builder.provide_methods(
Expand All @@ -23,3 +22,8 @@
)
class Fp(ctypes.Structure):
_fields_ = [("v", ctypes.c_ulonglong * consts.FP_SIZE)]

def __repr__(self):
import pdb

pdb.set_trace()
113 changes: 91 additions & 22 deletions mcl/structures/Fr.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,94 @@
import ctypes

from .. import builder
from .. import consts


@builder.provide_methods(
builder.method("__add__").using(builder.buildThreeOp).with_args("add"),
builder.method("__eq__").using(builder.buildIsEqual),
builder.method("__invert__").using(builder.buildTwoOp).with_args("inv"),
builder.method("__mul__").using(builder.buildThreeOp).with_args("mul"),
builder.method("__neg__").using(builder.buildTwoOp).with_args("neg"),
builder.method("__sub__").using(builder.buildThreeOp).with_args("sub"),
builder.method("__truediv__").using(builder.buildThreeOp).with_args("div"),
builder.method("deserialize"),
builder.method("getStr"),
builder.method("isOne"),
builder.method("isZero"),
builder.method("serialize"),
builder.method("setByCSPRNG"),
builder.method("setInt"),
builder.method("setStr"),
)
from .. import hook

mclbn384_256 = hook.mclbls12_384


class Fr(ctypes.Structure):
_fields_ = [("v", ctypes.c_ulonglong * consts.FR_SIZE)]
_fields_ = [("v", ctypes.c_ulonglong * 6)]

def __init__(self, value=None):
if value is None:
return

if isinstance(value, str):
self.setStr(value)
elif isinstance(value, int):
self.setInt(value)

def setInt(self, v):
mclbn384_256.mclBnFr_setInt(ctypes.byref(self.v), v)

def setStr(self, value: str):
value = value if isinstance(value, bytes) else value.encode()
error = mclbn384_256.mclBnFr_setStr(
ctypes.byref(self.v), ctypes.c_char_p(value), len(value), 10
)
if error:
raise RuntimeError("mclBnFr_setStr failed")

def getStr(self):
svLen = 2048
sv = ctypes.create_string_buffer(b"\0" * svLen)
mclbn384_256.mclBnFr_getStr(sv, svLen, ctypes.byref(self.v), 10)
return sv.value.decode()

def isZero(self):
return mclbn384_256.mclBnFr_isZero(ctypes.byref(self.v)) != 0

def isOne(self):
return mclbn384_256.mclBnFr_isOne(ctypes.byref(self.v)) != 0

def setByCSPRNG(self):
return mclbn384_256.mclBnFr_setByCSPRNG(ctypes.byref(self.v))

@classmethod
def random(cls):
value = cls()
value.setByCSPRNG()
return value

def __eq__(self, rhs):
return (
mclbn384_256.mclBnFr_isEqual(ctypes.byref(self.v), ctypes.byref(rhs.v)) != 0
)

def __ne__(self, rhs):
return not (self == rhs)

def __add__(self, rhs):
ret = Fr()
mclbn384_256.mclBnFr_add(
ctypes.byref(ret.v), ctypes.byref(self.v), ctypes.byref(rhs.v)
)
return ret

def __sub__(self, rhs):
ret = Fr()
mclbn384_256.mclBnFr_sub(
ctypes.byref(ret.v), ctypes.byref(self.v), ctypes.byref(rhs.v)
)
return ret

def __mul__(self, rhs):
ret = Fr()
mclbn384_256.mclBnFr_mul(
ctypes.byref(ret.v), ctypes.byref(self.v), ctypes.byref(rhs.v)
)
return ret

def __div__(self, rhs):
ret = Fr()
mclbn384_256.mclBnFr_div(
ctypes.byref(ret.v), ctypes.byref(self.v), ctypes.byref(rhs.v)
)
return ret

def __invert__(self):
ret = Fr()
mclbn384_256.mclBnFr_neg(ctypes.byref(ret.v), ctypes.byref(self.v))
return ret

def __repr__(self):
return f"Fr({self.getStr()})"
5 changes: 2 additions & 3 deletions mcl/structures/G1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import types
import ctypes
import types

from .. import utils
from .. import builder
from .. import builder, utils
from .Fp import Fp
from .Fr import Fr

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[tool.black]
exclude = 'venv'
47 changes: 38 additions & 9 deletions tests/test_fr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,58 @@
from mcl import Fr


def createRandomFr() -> Fr:
fr = Fr()
fr.setByCSPRNG()
return fr


class FrTests(unittest.TestCase):

def testAdd(self):
self.assertEqual(Fr(30), Fr(10) + Fr(20))

def testEquals(self):
self.assertEqual(Fr(20), Fr(20))

def testInvert(self):
self.assertEqual(Fr(10), ~~Fr(10))

def testMul(self):
self.assertEqual(Fr(8), Fr(4) * Fr(2))

def testInitFr(self):
self.assertIsNotNone(Fr())

def testSetStr(self):
Fr().setStr(b"12345678901234567")
# Arrange.
expected = "1234567"
fr = Fr()

# Act.
fr.setStr("1234567")

# Assert.
self.assertEqual(expected, fr.getStr())

def testIsEqual(self):
l = Fr()
l.setByCSPRNG()
self.assertTrue(l == l)
# Arrange.
fr = createRandomFr()

fr2 = Fr()
fr2.setStr(fr.getStr())

self.assertEqual(fr, fr2)

def testSetInt(self):
Fr().setInt(1)

def testMul(self):
Fr() * Fr()

def testGetStr(self):
fr = Fr()
fr.setStr(b"255")
fr.setStr("255")
s = fr.getStr()
self.assertEqual(b"255", s)
self.assertEqual("255", s)

def testByCSPRNG(self):
Fr().setByCSPRNG()

2 changes: 1 addition & 1 deletion tests/test_g1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from mcl import G1
from mcl import Fr

from . import test_data
from tests import test_data


class G1Tests(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_g2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from mcl import G2
from mcl import Fr

from . import test_data
from tests import test_data


class G2Tests(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mcl import G1
from mcl import Fr

from . import test_data
from tests import test_data


class GTTests(unittest.TestCase):
Expand Down