diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..8dd399a --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 88 +extend-ignore = E203 diff --git a/mcl/bindings.py b/mcl/bindings.py new file mode 100644 index 0000000..c997ca3 --- /dev/null +++ b/mcl/bindings.py @@ -0,0 +1,51 @@ +import ctypes +import inspect + +from . import hook, utils + +MCL_NAME = "__mcl_name__" + + +def is_mcl_method(obj): + return inspect.ismethod(obj) and hasattr(obj, MCL_NAME) + + +def method_binding(method_name=None): + def decorator(method_obj): + setattr(method_obj, MCL_NAME, method_name or method_obj.__name__) + return method_obj + + return decorator + + +def _build_mcl_ctypes_binding(method_obj, mcl_method_name, klass_instance): + wrapper = utils.wrap_function( + hook.mclbls12_384, + mcl_method_name, + [ + ctypes.POINTER(klass_instance), + ctypes.c_char_p, + ctypes.c_size_t, + ctypes.c_int64, + ], + ) + return wrapper + + +def class_binding(): + def decorator(klass_instance): + method_name_obj_pairs = inspect.getmembers( + klass_instance, predicate=is_mcl_method + ) + + for method_name, method_obj in method_name_obj_pairs: + mcl_method_name = getattr(method_obj, MCL_NAME) + setattr( + klass_instance, + method_name, + _build_mcl_ctypes_binding(method_obj, mcl_method_name, klass_instance), + ) + + return klass_instance + + return decorator diff --git a/mcl/builder.py b/mcl/builder.py index 7c0cb57..6203111 100644 --- a/mcl/builder.py +++ b/mcl/builder.py @@ -2,13 +2,12 @@ import dataclasses import inspect -from . import hook -from . import utils +from . import hook, utils BUFFER_SIZE = 2048 -def tryGetBuilderMethodFromGlobals(method_name: str) -> callable: +def tryGetBuilderMethodFromGlobals(method_name: str): return globals().get("build" + method_name[0].upper() + method_name[1:]) @@ -51,7 +50,12 @@ def buildSetStr(cls): wrapper = utils.wrap_function( hook.mclbls12_384, f"mclBn{cls.__name__}_setStr", - [ctypes.POINTER(cls), ctypes.c_char_p, ctypes.c_size_t, ctypes.c_int64], + [ + ctypes.POINTER(cls), + ctypes.c_char_p, + ctypes.c_size_t, + ctypes.c_int64, + ], ) def setStr(self, value, mode=10): @@ -75,7 +79,9 @@ def setInt(self, value): def buildSetByCSPRNG(cls): wrapper = utils.wrap_function( - hook.mclbls12_384, f"mclBn{cls.__name__}_setByCSPRNG", [ctypes.POINTER(cls)], + hook.mclbls12_384, + f"mclBn{cls.__name__}_setByCSPRNG", + [ctypes.POINTER(cls)], ) def setByCSPRNG(self): @@ -120,7 +126,9 @@ def isEqual(self, other): def buildIsOne(cls): wrapper = utils.wrap_function( - hook.mclbls12_384, f"mclBn{cls.__name__}_isOne", [ctypes.POINTER(cls)], + hook.mclbls12_384, + f"mclBn{cls.__name__}_isOne", + [ctypes.POINTER(cls)], ) def isOne(self, other): @@ -248,7 +256,11 @@ def buildPairing(cls, left_group, right_group): wrapper = utils.wrap_function( hook.mclbls12_384, f"mclBn_pairing", - (ctypes.POINTER(cls), ctypes.POINTER(left_group), ctypes.POINTER(right_group)), + ( + ctypes.POINTER(cls), + ctypes.POINTER(left_group), + ctypes.POINTER(right_group), + ), ) @staticmethod diff --git a/mcl/structures/Fp.py b/mcl/structures/Fp.py index 9bb65f3..524cb3c 100644 --- a/mcl/structures/Fp.py +++ b/mcl/structures/Fp.py @@ -1,7 +1,6 @@ import ctypes -from .. import builder -from .. import consts +from .. import builder, consts @builder.provide_methods( @@ -23,3 +22,8 @@ ) class Fp(ctypes.Structure): _fields_ = [("v", ctypes.c_ulonglong * consts.FP_SIZE)] + + def __repr__(self): + import pdb + + pdb.set_trace() diff --git a/mcl/structures/Fr.py b/mcl/structures/Fr.py index 95618ff..f0a4063 100644 --- a/mcl/structures/Fr.py +++ b/mcl/structures/Fr.py @@ -1,25 +1,94 @@ import ctypes -from .. import builder -from .. import consts - - -@builder.provide_methods( - builder.method("__add__").using(builder.buildThreeOp).with_args("add"), - builder.method("__eq__").using(builder.buildIsEqual), - builder.method("__invert__").using(builder.buildTwoOp).with_args("inv"), - builder.method("__mul__").using(builder.buildThreeOp).with_args("mul"), - builder.method("__neg__").using(builder.buildTwoOp).with_args("neg"), - builder.method("__sub__").using(builder.buildThreeOp).with_args("sub"), - builder.method("__truediv__").using(builder.buildThreeOp).with_args("div"), - builder.method("deserialize"), - builder.method("getStr"), - builder.method("isOne"), - builder.method("isZero"), - builder.method("serialize"), - builder.method("setByCSPRNG"), - builder.method("setInt"), - builder.method("setStr"), -) +from .. import hook + +mclbn384_256 = hook.mclbls12_384 + + class Fr(ctypes.Structure): - _fields_ = [("v", ctypes.c_ulonglong * consts.FR_SIZE)] + _fields_ = [("v", ctypes.c_ulonglong * 6)] + + def __init__(self, value=None): + if value is None: + return + + if isinstance(value, str): + self.setStr(value) + elif isinstance(value, int): + self.setInt(value) + + def setInt(self, v): + mclbn384_256.mclBnFr_setInt(ctypes.byref(self.v), v) + + def setStr(self, value: str): + value = value if isinstance(value, bytes) else value.encode() + error = mclbn384_256.mclBnFr_setStr( + ctypes.byref(self.v), ctypes.c_char_p(value), len(value), 10 + ) + if error: + raise RuntimeError("mclBnFr_setStr failed") + + def getStr(self): + svLen = 2048 + sv = ctypes.create_string_buffer(b"\0" * svLen) + mclbn384_256.mclBnFr_getStr(sv, svLen, ctypes.byref(self.v), 10) + return sv.value.decode() + + def isZero(self): + return mclbn384_256.mclBnFr_isZero(ctypes.byref(self.v)) != 0 + + def isOne(self): + return mclbn384_256.mclBnFr_isOne(ctypes.byref(self.v)) != 0 + + def setByCSPRNG(self): + return mclbn384_256.mclBnFr_setByCSPRNG(ctypes.byref(self.v)) + + @classmethod + def random(cls): + value = cls() + value.setByCSPRNG() + return value + + def __eq__(self, rhs): + return ( + mclbn384_256.mclBnFr_isEqual(ctypes.byref(self.v), ctypes.byref(rhs.v)) != 0 + ) + + def __ne__(self, rhs): + return not (self == rhs) + + def __add__(self, rhs): + ret = Fr() + mclbn384_256.mclBnFr_add( + ctypes.byref(ret.v), ctypes.byref(self.v), ctypes.byref(rhs.v) + ) + return ret + + def __sub__(self, rhs): + ret = Fr() + mclbn384_256.mclBnFr_sub( + ctypes.byref(ret.v), ctypes.byref(self.v), ctypes.byref(rhs.v) + ) + return ret + + def __mul__(self, rhs): + ret = Fr() + mclbn384_256.mclBnFr_mul( + ctypes.byref(ret.v), ctypes.byref(self.v), ctypes.byref(rhs.v) + ) + return ret + + def __div__(self, rhs): + ret = Fr() + mclbn384_256.mclBnFr_div( + ctypes.byref(ret.v), ctypes.byref(self.v), ctypes.byref(rhs.v) + ) + return ret + + def __invert__(self): + ret = Fr() + mclbn384_256.mclBnFr_neg(ctypes.byref(ret.v), ctypes.byref(self.v)) + return ret + + def __repr__(self): + return f"Fr({self.getStr()})" diff --git a/mcl/structures/G1.py b/mcl/structures/G1.py index 69152fe..e260e2c 100644 --- a/mcl/structures/G1.py +++ b/mcl/structures/G1.py @@ -1,8 +1,7 @@ -import types import ctypes +import types -from .. import utils -from .. import builder +from .. import builder, utils from .Fp import Fp from .Fr import Fr diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6d6e6fb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +exclude = 'venv' diff --git a/tests/test_fr.py b/tests/test_fr.py index c3d9e81..641868e 100644 --- a/tests/test_fr.py +++ b/tests/test_fr.py @@ -3,29 +3,58 @@ from mcl import Fr +def createRandomFr() -> Fr: + fr = Fr() + fr.setByCSPRNG() + return fr + + class FrTests(unittest.TestCase): + + def testAdd(self): + self.assertEqual(Fr(30), Fr(10) + Fr(20)) + + def testEquals(self): + self.assertEqual(Fr(20), Fr(20)) + + def testInvert(self): + self.assertEqual(Fr(10), ~~Fr(10)) + + def testMul(self): + self.assertEqual(Fr(8), Fr(4) * Fr(2)) + def testInitFr(self): self.assertIsNotNone(Fr()) def testSetStr(self): - Fr().setStr(b"12345678901234567") + # Arrange. + expected = "1234567" + fr = Fr() + + # Act. + fr.setStr("1234567") + + # Assert. + self.assertEqual(expected, fr.getStr()) def testIsEqual(self): - l = Fr() - l.setByCSPRNG() - self.assertTrue(l == l) + # Arrange. + fr = createRandomFr() + + fr2 = Fr() + fr2.setStr(fr.getStr()) + + self.assertEqual(fr, fr2) def testSetInt(self): Fr().setInt(1) - def testMul(self): - Fr() * Fr() - def testGetStr(self): fr = Fr() - fr.setStr(b"255") + fr.setStr("255") s = fr.getStr() - self.assertEqual(b"255", s) + self.assertEqual("255", s) def testByCSPRNG(self): Fr().setByCSPRNG() + diff --git a/tests/test_g1.py b/tests/test_g1.py index 9faa9f4..df76bc4 100644 --- a/tests/test_g1.py +++ b/tests/test_g1.py @@ -3,7 +3,7 @@ from mcl import G1 from mcl import Fr -from . import test_data +from tests import test_data class G1Tests(unittest.TestCase): diff --git a/tests/test_g2.py b/tests/test_g2.py index 0045546..97a7827 100644 --- a/tests/test_g2.py +++ b/tests/test_g2.py @@ -3,7 +3,7 @@ from mcl import G2 from mcl import Fr -from . import test_data +from tests import test_data class G2Tests(unittest.TestCase): diff --git a/tests/test_gt.py b/tests/test_gt.py index 2b527cf..8cae55d 100644 --- a/tests/test_gt.py +++ b/tests/test_gt.py @@ -5,7 +5,7 @@ from mcl import G1 from mcl import Fr -from . import test_data +from tests import test_data class GTTests(unittest.TestCase):