From fbff7bfa7068293538586be3edbab99f29d84812 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 25 Dec 2025 22:39:45 +0800 Subject: [PATCH 1/3] refactor: clean up imports and remove unnecessary newlines --- brainpy/__init__.py | 11 ++-- brainpy/analysis/__init__.py | 1 - brainpy/analysis/highdim/slow_points.py | 2 +- brainpy/channels.py | 2 - brainpy/check.py | 1 - brainpy/connect/base.py | 2 +- brainpy/connect/custom_conn.py | 2 +- brainpy/context.py | 1 + brainpy/delay.py | 2 +- brainpy/dnn/__init__.py | 1 - brainpy/dnn/conv.py | 2 +- brainpy/dnn/linear.py | 8 +-- brainpy/dyn/base.py | 2 +- brainpy/dyn/channels/__init__.py | 1 - brainpy/dyn/channels/base.py | 2 +- brainpy/dyn/channels/potassium_calcium.py | 2 +- brainpy/dyn/ions/base.py | 5 +- brainpy/dyn/neurons/hh.py | 2 +- brainpy/dyn/others/input.py | 2 +- brainpy/dyn/outs/base.py | 2 +- brainpy/dyn/projections/align_post.py | 4 +- brainpy/dyn/projections/align_pre.py | 2 +- brainpy/dyn/projections/base.py | 2 +- brainpy/dyn/projections/delta.py | 2 +- brainpy/dyn/projections/inputs.py | 2 +- brainpy/dyn/projections/plasticity.py | 4 +- brainpy/dyn/projections/utils.py | 1 - brainpy/dyn/projections/vanilla.py | 2 +- brainpy/dyn/rates/populations.py | 10 +-- brainpy/dyn/rates/rnncells.py | 14 ++-- brainpy/dyn/synapses/abstract_models.py | 2 +- brainpy/dynold/__init__.py | 2 +- brainpy/dynold/experimental/__init__.py | 2 +- brainpy/dynold/neurons/biological_models.py | 8 +-- brainpy/dynold/neurons/reduced_models.py | 10 +-- brainpy/dynold/synapses/base.py | 4 +- brainpy/dynold/synapses/learning_rules.py | 2 +- brainpy/dynsys.py | 4 +- brainpy/inputs/currents.py | 5 +- brainpy/inputs/tests/test_currents.py | 2 - brainpy/integrators/fde/Caputo.py | 2 +- brainpy/integrators/ode/base.py | 2 +- brainpy/integrators/pde/__init__.py | 2 +- brainpy/integrators/runner.py | 2 +- brainpy/integrators/tests/test_joint_eq.py | 1 + brainpy/layers.py | 2 - brainpy/math/__init__.py | 3 +- brainpy/math/_utils.py | 4 -- brainpy/math/activations.py | 3 +- brainpy/math/defaults.py | 2 +- brainpy/math/delayvars.py | 2 +- brainpy/math/environment.py | 2 +- brainpy/math/jitconn/event_matvec.py | 4 +- brainpy/math/ndarray.py | 1 - brainpy/math/object_transform/__init__.py | 3 + brainpy/math/object_transform/_utils.py | 2 +- brainpy/math/object_transform/autograd.py | 1 + brainpy/math/object_transform/collectors.py | 2 +- brainpy/math/object_transform/controls.py | 66 +++++++++---------- brainpy/math/object_transform/jit.py | 16 ++--- .../object_transform/tests/test_collector.py | 1 - .../object_transform/tests/test_controls.py | 11 ++-- brainpy/math/sparse/utils.py | 2 - brainpy/math/surrogate/_utils.py | 2 +- brainpy/math/tests/test_numpy_ops.py | 4 +- brainpy/mixin.py | 3 +- brainpy/neurons.py | 60 ++++++++--------- brainpy/optim/optimizer.py | 2 +- brainpy/optim/scheduler.py | 4 +- brainpy/rates.py | 2 - brainpy/runners.py | 4 +- brainpy/running/runner.py | 2 +- brainpy/synouts.py | 3 +- brainpy/test_main.py | 18 +++++ brainpy/tools/dicts.py | 3 +- brainpy/tools/others.py | 1 - brainpy/train/back_propagation.py | 7 +- brainpy/train/offline.py | 4 +- brainpy/train/online.py | 4 +- brainpy/transform.py | 3 +- brainpy/visualization.py | 1 - 81 files changed, 196 insertions(+), 204 deletions(-) create mode 100644 brainpy/test_main.py diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 41c09ad55..2f84e37c3 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -17,15 +17,14 @@ __version__ = "2.7.5" __version_info__ = tuple(map(int, __version__.split("."))) - from brainpy import _errors as errors -from brainpy import mixin # fundamental supporting modules from brainpy import check, tools # Part: Math Foundation # # ----------------------- # # math foundation from brainpy import math +from brainpy import mixin # Part: Toolbox # # --------------- # # modules of toolbox @@ -109,11 +108,11 @@ # ---------------- # from brainpy.train.base import (DSTrainer as DSTrainer, ) from brainpy.train.back_propagation import (BPTT as BPTT, - BPFF as BPFF, ) + BPFF as BPFF, ) from brainpy.train.online import (OnlineTrainer as OnlineTrainer, - ForceTrainer as ForceTrainer, ) + ForceTrainer as ForceTrainer, ) from brainpy.train.offline import (OfflineTrainer as OfflineTrainer, - RidgeTrainer as RidgeTrainer, ) + RidgeTrainer as RidgeTrainer, ) # Part: Analysis # # ---------------- # @@ -152,8 +151,6 @@ except: pass - - if __name__ == '__main__': connect initialize, # weight initialization diff --git a/brainpy/analysis/__init__.py b/brainpy/analysis/__init__.py index 2dc37bfc3..76f3a1d27 100644 --- a/brainpy/analysis/__init__.py +++ b/brainpy/analysis/__init__.py @@ -34,4 +34,3 @@ from .highdim.slow_points import * from .lowdim.lowdim_bifurcation import * from .lowdim.lowdim_phase_plane import * - diff --git a/brainpy/analysis/highdim/slow_points.py b/brainpy/analysis/highdim/slow_points.py index f6cc02d22..24cb5abdd 100644 --- a/brainpy/analysis/highdim/slow_points.py +++ b/brainpy/analysis/highdim/slow_points.py @@ -25,8 +25,8 @@ from jax.scipy.optimize import minimize import brainpy.math as bm -from brainpy._errors import AnalyzerError, UnsupportedError from brainpy import optim, losses +from brainpy._errors import AnalyzerError, UnsupportedError from brainpy.analysis import utils, base, constants from brainpy.context import share from brainpy.deprecations import _input_deprecate_msg diff --git a/brainpy/channels.py b/brainpy/channels.py index 54d26beb3..d4157861b 100644 --- a/brainpy/channels.py +++ b/brainpy/channels.py @@ -17,11 +17,9 @@ This module has been deprecated since brainpy>=2.4.0. Use ``brainpy.dyn`` module instead. """ - from .dyn.channels import * from .dyn.ions import * - if __name__ == '__main__': IL Potassium diff --git a/brainpy/check.py b/brainpy/check.py index 4bf048e92..30dbc3174 100644 --- a/brainpy/check.py +++ b/brainpy/check.py @@ -647,4 +647,3 @@ def true_err_fun(arg, transforms): cond(remove_vmap(as_jax(pred)), lambda: jax.pure_callback(true_err_fun, None), lambda: None) - diff --git a/brainpy/connect/base.py b/brainpy/connect/base.py index c4aea96f5..8e7ad9260 100644 --- a/brainpy/connect/base.py +++ b/brainpy/connect/base.py @@ -20,8 +20,8 @@ import jax.numpy as jnp import numpy as onp -from brainpy._errors import ConnectorError from brainpy import tools, math as bm +from brainpy._errors import ConnectorError __all__ = [ # the connection types diff --git a/brainpy/connect/custom_conn.py b/brainpy/connect/custom_conn.py index bff088a56..4d751fc43 100644 --- a/brainpy/connect/custom_conn.py +++ b/brainpy/connect/custom_conn.py @@ -17,8 +17,8 @@ import jax.numpy as jnp import numpy as np -from brainpy._errors import ConnectorError from brainpy import math as bm, tools +from brainpy._errors import ConnectorError from .base import * __all__ = [ diff --git a/brainpy/context.py b/brainpy/context.py index d9d6db78a..0ef8c42c4 100644 --- a/brainpy/context.py +++ b/brainpy/context.py @@ -21,6 +21,7 @@ from typing import Any, Union import brainstate + from brainpy.math.defaults import env from brainpy.tools.dicts import DotDict diff --git a/brainpy/delay.py b/brainpy/delay.py index a283c5916..ea5e7c22f 100644 --- a/brainpy/delay.py +++ b/brainpy/delay.py @@ -24,13 +24,13 @@ import jax.numpy as jnp import numpy as np -from brainpy.mixin import ParamDesc, ReturnInfo, JointType, SupportAutoDelay from brainpy import check, math as bm from brainpy.check import jit_error from brainpy.context import share from brainpy.dynsys import DynamicalSystem from brainpy.initialize import variable_ from brainpy.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE +from brainpy.mixin import ParamDesc, ReturnInfo, JointType, SupportAutoDelay __all__ = [ 'Delay', diff --git a/brainpy/dnn/__init__.py b/brainpy/dnn/__init__.py index 750742574..1e549c564 100644 --- a/brainpy/dnn/__init__.py +++ b/brainpy/dnn/__init__.py @@ -22,4 +22,3 @@ from .linear import * from .normalization import * from .pooling import * - diff --git a/brainpy/dnn/conv.py b/brainpy/dnn/conv.py index 85fd54526..dc39b597f 100644 --- a/brainpy/dnn/conv.py +++ b/brainpy/dnn/conv.py @@ -15,9 +15,9 @@ # ============================================================================== from typing import Union, Tuple, Optional, Sequence, Callable +import brainstate from jax import lax -import brainstate from brainpy import math as bm, tools from brainpy.dnn.base import Layer from brainpy.initialize import Initializer, XavierNormal, ZeroInit, parameter diff --git a/brainpy/dnn/linear.py b/brainpy/dnn/linear.py index 5e3dd494e..87dd0435c 100644 --- a/brainpy/dnn/linear.py +++ b/brainpy/dnn/linear.py @@ -19,18 +19,18 @@ import jax import jax.numpy as jnp import numpy as np -from brainevent._csr_impl_plasticity import csr_on_pre, csr2csc_on_post -from brainevent._dense_impl_plasticity import dense_on_pre, dense_on_post +from brainevent import csr_on_pre, csr2csc_on_post +from brainevent import dense_on_pre, dense_on_post -from brainpy._errors import MathError -from brainpy.mixin import SupportOnline, SupportOffline, SupportSTDP from brainpy import connect, initialize as init from brainpy import math as bm +from brainpy._errors import MathError from brainpy.check import is_initializer from brainpy.connect import csr2csc from brainpy.context import share from brainpy.dnn.base import Layer from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter +from brainpy.mixin import SupportOnline, SupportOffline, SupportSTDP from brainpy.types import ArrayType, Sharding __all__ = [ diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index a9f1fce14..d7b0a0cd8 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from brainpy.mixin import SupportAutoDelay, ParamDesc from brainpy.dynsys import Dynamic +from brainpy.mixin import SupportAutoDelay, ParamDesc __all__ = [ 'NeuDyn', 'SynDyn', 'IonChaDyn', diff --git a/brainpy/dyn/channels/__init__.py b/brainpy/dyn/channels/__init__.py index 682fcdbcd..38e2d3510 100644 --- a/brainpy/dyn/channels/__init__.py +++ b/brainpy/dyn/channels/__init__.py @@ -23,4 +23,3 @@ from .potassium_compatible import * from .sodium import * from .sodium_compatible import * - diff --git a/brainpy/dyn/channels/base.py b/brainpy/dyn/channels/base.py index fb4a1da9d..add9e403c 100644 --- a/brainpy/dyn/channels/base.py +++ b/brainpy/dyn/channels/base.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from brainpy.mixin import TreeNode from brainpy.dyn.base import IonChaDyn from brainpy.dyn.neurons.hh import HHTypedNeuron +from brainpy.mixin import TreeNode __all__ = [ 'IonChannel', diff --git a/brainpy/dyn/channels/potassium_calcium.py b/brainpy/dyn/channels/potassium_calcium.py index d1fd13a59..858c73757 100644 --- a/brainpy/dyn/channels/potassium_calcium.py +++ b/brainpy/dyn/channels/potassium_calcium.py @@ -20,12 +20,12 @@ from typing import Union, Callable, Optional import brainpy.math as bm -from brainpy.mixin import JointType from brainpy.context import share from brainpy.dyn.ions.calcium import Calcium from brainpy.dyn.ions.potassium import Potassium from brainpy.initialize import Initializer, parameter, variable from brainpy.integrators.ode.generic import odeint +from brainpy.mixin import JointType from brainpy.types import Shape, ArrayType from .calcium import CalciumChannel from .potassium import PotassiumChannel diff --git a/brainpy/dyn/ions/base.py b/brainpy/dyn/ions/base.py index 0d0e3ee9a..3a278c217 100644 --- a/brainpy/dyn/ions/base.py +++ b/brainpy/dyn/ions/base.py @@ -15,12 +15,13 @@ # ============================================================================== from typing import Union, Optional, Dict, Sequence, Callable +from brainstate.mixin import _JointGenericAlias + import brainpy.math as bm -from brainpy.mixin import Container, TreeNode from brainpy.dyn.base import IonChaDyn from brainpy.dyn.neurons.hh import HHTypedNeuron +from brainpy.mixin import Container, TreeNode from brainpy.types import Shape -from brainstate.mixin import _JointGenericAlias __all__ = [ 'MixIons', diff --git a/brainpy/dyn/neurons/hh.py b/brainpy/dyn/neurons/hh.py index f7d6b7241..77b6d91a3 100644 --- a/brainpy/dyn/neurons/hh.py +++ b/brainpy/dyn/neurons/hh.py @@ -17,7 +17,6 @@ from typing import Union, Callable, Optional import brainpy.math as bm -from brainpy.mixin import Container, TreeNode from brainpy.check import is_initializer from brainpy.context import share from brainpy.dyn.base import NeuDyn, IonChaDyn @@ -25,6 +24,7 @@ from brainpy.initialize import Uniform, variable_, noise as init_noise from brainpy.integrators import JointEq from brainpy.integrators import odeint, sdeint +from brainpy.mixin import Container, TreeNode from brainpy.types import ArrayType from brainpy.types import Shape diff --git a/brainpy/dyn/others/input.py b/brainpy/dyn/others/input.py index d8732fdd8..1272eaa5a 100644 --- a/brainpy/dyn/others/input.py +++ b/brainpy/dyn/others/input.py @@ -19,12 +19,12 @@ import jax import jax.numpy as jnp -from brainpy.mixin import ReturnInfo from brainpy import math as bm from brainpy.context import share from brainpy.dyn.base import NeuDyn from brainpy.dyn.utils import get_spk_type from brainpy.initialize import parameter, variable_ +from brainpy.mixin import ReturnInfo from brainpy.types import Shape, ArrayType __all__ = [ diff --git a/brainpy/dyn/outs/base.py b/brainpy/dyn/outs/base.py index 3c9571b07..baaa435e8 100644 --- a/brainpy/dyn/outs/base.py +++ b/brainpy/dyn/outs/base.py @@ -15,8 +15,8 @@ from typing import Optional import brainpy.math as bm -from brainpy.mixin import ParamDesc, BindCondData from brainpy.dynsys import DynamicalSystem +from brainpy.mixin import ParamDesc, BindCondData __all__ = [ 'SynOut' diff --git a/brainpy/dyn/projections/align_post.py b/brainpy/dyn/projections/align_post.py index 9f8f55b21..af192cc68 100644 --- a/brainpy/dyn/projections/align_post.py +++ b/brainpy/dyn/projections/align_post.py @@ -14,11 +14,11 @@ # ============================================================================== from typing import Optional, Callable, Union -from brainpy.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData, AlignPost) from brainpy import math as bm, check from brainpy.delay import (delay_identifier, - register_delay_by_return) + register_delay_by_return) from brainpy.dynsys import DynamicalSystem, Projection +from brainpy.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData, AlignPost) __all__ = [ 'HalfProjAlignPostMg', 'FullProjAlignPostMg', diff --git a/brainpy/dyn/projections/align_pre.py b/brainpy/dyn/projections/align_pre.py index 171a996fa..13a8f7d51 100644 --- a/brainpy/dyn/projections/align_pre.py +++ b/brainpy/dyn/projections/align_pre.py @@ -14,10 +14,10 @@ # ============================================================================== from typing import Optional, Union -from brainpy.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData) from brainpy import math as bm, check from brainpy.delay import (Delay, DelayAccess, init_delay_by_return, register_delay_by_return) from brainpy.dynsys import DynamicalSystem, Projection +from brainpy.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData) from .utils import _get_return __all__ = [ diff --git a/brainpy/dyn/projections/base.py b/brainpy/dyn/projections/base.py index f50c118b4..9bdca17d2 100644 --- a/brainpy/dyn/projections/base.py +++ b/brainpy/dyn/projections/base.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from brainpy.mixin import ReturnInfo from brainpy import math as bm +from brainpy.mixin import ReturnInfo def _get_return(return_info): diff --git a/brainpy/dyn/projections/delta.py b/brainpy/dyn/projections/delta.py index 3c968a796..fdcb5cde2 100644 --- a/brainpy/dyn/projections/delta.py +++ b/brainpy/dyn/projections/delta.py @@ -14,10 +14,10 @@ # ============================================================================== from typing import Optional, Union -from brainpy.mixin import (JointType, SupportAutoDelay) from brainpy import math as bm, check from brainpy.delay import (delay_identifier, register_delay_by_return) from brainpy.dynsys import DynamicalSystem, Projection +from brainpy.mixin import (JointType, SupportAutoDelay) __all__ = [ 'HalfProjDelta', 'FullProjDelta', diff --git a/brainpy/dyn/projections/inputs.py b/brainpy/dyn/projections/inputs.py index 76d1ff91c..5c78ce210 100644 --- a/brainpy/dyn/projections/inputs.py +++ b/brainpy/dyn/projections/inputs.py @@ -16,11 +16,11 @@ from typing import Any from typing import Union, Optional -from brainpy.mixin import SupportAutoDelay from brainpy import check, math as bm from brainpy.context import share from brainpy.dynsys import Dynamic from brainpy.dynsys import Projection +from brainpy.mixin import SupportAutoDelay from brainpy.types import Shape __all__ = [ diff --git a/brainpy/dyn/projections/plasticity.py b/brainpy/dyn/projections/plasticity.py index fecb490d9..c6ccdd823 100644 --- a/brainpy/dyn/projections/plasticity.py +++ b/brainpy/dyn/projections/plasticity.py @@ -14,12 +14,12 @@ # ============================================================================== from typing import Optional, Callable, Union -from brainpy.mixin import (JointType, ParamDescriber, SupportAutoDelay, - BindCondData, AlignPost, SupportSTDP) from brainpy import math as bm, check from brainpy.delay import register_delay_by_return from brainpy.dyn.synapses.abstract_models import Expon from brainpy.dynsys import DynamicalSystem, Projection +from brainpy.mixin import (JointType, ParamDescriber, SupportAutoDelay, + BindCondData, AlignPost, SupportSTDP) from brainpy.types import ArrayType from .align_post import (align_post_add_bef_update, ) from .align_pre import (align_pre2_add_bef_update, ) diff --git a/brainpy/dyn/projections/utils.py b/brainpy/dyn/projections/utils.py index f50c118b4..4ddfbdb57 100644 --- a/brainpy/dyn/projections/utils.py +++ b/brainpy/dyn/projections/utils.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== from brainpy.mixin import ReturnInfo -from brainpy import math as bm def _get_return(return_info): diff --git a/brainpy/dyn/projections/vanilla.py b/brainpy/dyn/projections/vanilla.py index 9c70aed5c..f35ac7622 100644 --- a/brainpy/dyn/projections/vanilla.py +++ b/brainpy/dyn/projections/vanilla.py @@ -14,9 +14,9 @@ # ============================================================================== from typing import Optional -from brainpy.mixin import (JointType, BindCondData) from brainpy import math as bm, check from brainpy.dynsys import DynamicalSystem, Projection +from brainpy.mixin import (JointType, BindCondData) __all__ = [ 'VanillaProj', diff --git a/brainpy/dyn/rates/populations.py b/brainpy/dyn/rates/populations.py index 71486a681..c774c5fd0 100644 --- a/brainpy/dyn/rates/populations.py +++ b/brainpy/dyn/rates/populations.py @@ -23,11 +23,11 @@ from brainpy.dyn.base import NeuDyn from brainpy.dyn.others.noise import OUProcess from brainpy.initialize import (Initializer, - Uniform, - parameter, - variable, - variable_, - ZeroInit) + Uniform, + parameter, + variable, + variable_, + ZeroInit) from brainpy.integrators.joint_eq import JointEq from brainpy.integrators.ode.generic import odeint from brainpy.types import Shape, ArrayType diff --git a/brainpy/dyn/rates/rnncells.py b/brainpy/dyn/rates/rnncells.py index 30338be80..ef4107319 100644 --- a/brainpy/dyn/rates/rnncells.py +++ b/brainpy/dyn/rates/rnncells.py @@ -19,16 +19,16 @@ import brainpy.math as bm from brainpy.check import (is_integer, - is_initializer) + is_initializer) from brainpy.dnn.base import Layer from brainpy.dnn.conv import _GeneralConv from brainpy.initialize import (XavierNormal, - ZeroInit, - Orthogonal, - parameter, - variable, - variable_, - Initializer) + ZeroInit, + Orthogonal, + parameter, + variable, + variable_, + Initializer) from brainpy.math import activations from brainpy.types import ArrayType diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index a2215e768..0390fd2a7 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -14,7 +14,6 @@ # ============================================================================== from typing import Union, Sequence, Callable, Optional -from brainpy.mixin import AlignPost, ReturnInfo from brainpy import math as bm from brainpy.context import share from brainpy.dyn import _docs @@ -22,6 +21,7 @@ from brainpy.initialize import parameter from brainpy.integrators.joint_eq import JointEq from brainpy.integrators.ode.generic import odeint +from brainpy.mixin import AlignPost, ReturnInfo from brainpy.types import ArrayType __all__ = [ diff --git a/brainpy/dynold/__init__.py b/brainpy/dynold/__init__.py index 2c035cd2b..19a103570 100644 --- a/brainpy/dynold/__init__.py +++ b/brainpy/dynold/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== \ No newline at end of file +# ============================================================================== diff --git a/brainpy/dynold/experimental/__init__.py b/brainpy/dynold/experimental/__init__.py index 2c035cd2b..19a103570 100644 --- a/brainpy/dynold/experimental/__init__.py +++ b/brainpy/dynold/experimental/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== \ No newline at end of file +# ============================================================================== diff --git a/brainpy/dynold/neurons/biological_models.py b/brainpy/dynold/neurons/biological_models.py index 12fb803c8..ddaca03da 100644 --- a/brainpy/dynold/neurons/biological_models.py +++ b/brainpy/dynold/neurons/biological_models.py @@ -21,10 +21,10 @@ from brainpy.dyn.base import NeuDyn from brainpy.dyn.neurons import hh from brainpy.initialize import (OneInit, - Initializer, - parameter, - noise as init_noise, - variable_) + Initializer, + parameter, + noise as init_noise, + variable_) from brainpy.integrators.joint_eq import JointEq from brainpy.integrators.ode.generic import odeint from brainpy.integrators.sde.generic import sdeint diff --git a/brainpy/dynold/neurons/reduced_models.py b/brainpy/dynold/neurons/reduced_models.py index df5ae9d46..84a51f952 100644 --- a/brainpy/dynold/neurons/reduced_models.py +++ b/brainpy/dynold/neurons/reduced_models.py @@ -23,11 +23,11 @@ from brainpy.dyn.base import NeuDyn from brainpy.dyn.neurons import lif from brainpy.initialize import (ZeroInit, - OneInit, - Initializer, - parameter, - variable_, - noise as init_noise) + OneInit, + Initializer, + parameter, + variable_, + noise as init_noise) from brainpy.integrators import sdeint, odeint, JointEq from brainpy.types import Shape, ArrayType diff --git a/brainpy/dynold/synapses/base.py b/brainpy/dynold/synapses/base.py index 23f1e5611..eaf3ad6af 100644 --- a/brainpy/dynold/synapses/base.py +++ b/brainpy/dynold/synapses/base.py @@ -17,15 +17,15 @@ import jax -from brainpy._errors import UnsupportedError -from brainpy.mixin import (ParamDesc, JointType, SupportAutoDelay, BindCondData, ReturnInfo) from brainpy import math as bm +from brainpy._errors import UnsupportedError from brainpy.connect import TwoEndConnector, One2One, All2All from brainpy.dnn import linear from brainpy.dyn.base import NeuDyn from brainpy.dyn.projections.conn import SynConn from brainpy.dynsys import DynamicalSystem from brainpy.initialize import parameter +from brainpy.mixin import (ParamDesc, JointType, SupportAutoDelay, BindCondData, ReturnInfo) from brainpy.types import ArrayType __all__ = [ diff --git a/brainpy/dynold/synapses/learning_rules.py b/brainpy/dynold/synapses/learning_rules.py index 1c4738f37..b413c7bb3 100644 --- a/brainpy/dynold/synapses/learning_rules.py +++ b/brainpy/dynold/synapses/learning_rules.py @@ -15,7 +15,6 @@ # ============================================================================== from typing import Union, Dict, Callable, Optional -from brainpy.mixin import ParamDesc from brainpy.connect import TwoEndConnector from brainpy.dyn import synapses from brainpy.dyn.base import NeuDyn @@ -23,6 +22,7 @@ from brainpy.dynold.synouts import CUBA from brainpy.dynsys import Sequential from brainpy.initialize import Initializer +from brainpy.mixin import ParamDesc from brainpy.types import ArrayType __all__ = [ diff --git a/brainpy/dynsys.py b/brainpy/dynsys.py index eb6e1807e..6d56e3915 100644 --- a/brainpy/dynsys.py +++ b/brainpy/dynsys.py @@ -22,13 +22,13 @@ import jax import numpy as np -from brainpy._errors import NoImplementationError, UnsupportedError -from brainpy.mixin import SupportAutoDelay, Container, SupportInputProj, _get_delay_tool, MixIn from brainpy import tools, math as bm +from brainpy._errors import NoImplementationError, UnsupportedError from brainpy.context import share from brainpy.deprecations import _update_deprecate_msg from brainpy.initialize import parameter, variable_ from brainpy.math.object_transform.naming import get_unique_name +from brainpy.mixin import SupportAutoDelay, Container, SupportInputProj, _get_delay_tool, MixIn from brainpy.types import ArrayType, Shape __all__ = [ diff --git a/brainpy/inputs/currents.py b/brainpy/inputs/currents.py index 6f313700a..2b40686f0 100644 --- a/brainpy/inputs/currents.py +++ b/brainpy/inputs/currents.py @@ -15,9 +15,9 @@ # ============================================================================== import warnings +import brainstate import braintools -import brainstate import brainpy.math __all__ = [ @@ -294,4 +294,5 @@ def square_input(amplitude, frequency, duration, dt=None, bias=False, t_start=0. has a positive DC bias, thus non-negative (True). """ with brainstate.environ.context(dt=brainpy.math.get_dt() if dt is None else dt): - return braintools.input.square(amplitude, frequency, duration, t_start=t_start, t_end=t_end, duty_cycle=0.5, bias=bias) + return braintools.input.square(amplitude, frequency, duration, t_start=t_start, t_end=t_end, duty_cycle=0.5, + bias=bias) diff --git a/brainpy/inputs/tests/test_currents.py b/brainpy/inputs/tests/test_currents.py index bca7e9c7c..329d88cbe 100644 --- a/brainpy/inputs/tests/test_currents.py +++ b/brainpy/inputs/tests/test_currents.py @@ -15,7 +15,6 @@ # ============================================================================== from unittest import TestCase -import brainunit as u import numpy as np import brainpy as bp @@ -38,7 +37,6 @@ def show(current, duration, title=''): class TestCurrents(TestCase): - def test_section_input(self): current1, duration = bp.inputs.section_input(values=[0, 1., 0.], durations=[100, 300, 100], diff --git a/brainpy/integrators/fde/Caputo.py b/brainpy/integrators/fde/Caputo.py index 58e634b97..0f48fc586 100644 --- a/brainpy/integrators/fde/Caputo.py +++ b/brainpy/integrators/fde/Caputo.py @@ -24,8 +24,8 @@ from scipy.special import gamma, rgamma import brainpy.math as bm -from brainpy._errors import UnsupportedError from brainpy import check +from brainpy._errors import UnsupportedError from brainpy.integrators.constants import DT from brainpy.integrators.utils import check_inits, format_args from brainpy.types import ArrayType diff --git a/brainpy/integrators/ode/base.py b/brainpy/integrators/ode/base.py index ac1d38ee0..22ed7c8d5 100644 --- a/brainpy/integrators/ode/base.py +++ b/brainpy/integrators/ode/base.py @@ -15,8 +15,8 @@ # ============================================================================== from typing import Dict, Callable, Union -from brainpy._errors import DiffEqError, CodeError from brainpy import math as bm +from brainpy._errors import DiffEqError, CodeError from brainpy.check import is_dict_data from brainpy.integrators import constants, utils from brainpy.integrators.base import Integrator diff --git a/brainpy/integrators/pde/__init__.py b/brainpy/integrators/pde/__init__.py index 6ca5623cc..c5fc2d7fe 100644 --- a/brainpy/integrators/pde/__init__.py +++ b/brainpy/integrators/pde/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== \ No newline at end of file +# ============================================================================== diff --git a/brainpy/integrators/runner.py b/brainpy/integrators/runner.py index b29911938..54c78a9f2 100644 --- a/brainpy/integrators/runner.py +++ b/brainpy/integrators/runner.py @@ -24,8 +24,8 @@ import tqdm.auto from jax.tree_util import tree_flatten -from brainpy._errors import RunningError from brainpy import math as bm +from brainpy._errors import RunningError from brainpy.math.object_transform.base import Collector from brainpy.running.runner import Runner from .base import Integrator diff --git a/brainpy/integrators/tests/test_joint_eq.py b/brainpy/integrators/tests/test_joint_eq.py index e617ab814..8f9fe7960 100644 --- a/brainpy/integrators/tests/test_joint_eq.py +++ b/brainpy/integrators/tests/test_joint_eq.py @@ -155,6 +155,7 @@ def dv(v, t, x): def test_second_order_ode_wrong_signature(self): """Test that wrong signature gives helpful error message""" + # WRONG: both x and v before t in dx function def dx_wrong(x, v, t): return v diff --git a/brainpy/layers.py b/brainpy/layers.py index 91dfe4cb3..b2abddbd4 100644 --- a/brainpy/layers.py +++ b/brainpy/layers.py @@ -16,9 +16,7 @@ This module has been deprecated since brainpy>=2.4.0. Use ``brainpy.dnn`` module instead. """ - from .dnn import * if __name__ == '__main__': Dropout - diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index b8f512272..bc6382190 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -40,9 +40,8 @@ # the index update is the same way with the numpy # -import braintools - import brainstate +import braintools random = brainstate.random surrogate = braintools.surrogate diff --git a/brainpy/math/_utils.py b/brainpy/math/_utils.py index 3058d2add..02ea058d6 100644 --- a/brainpy/math/_utils.py +++ b/brainpy/math/_utils.py @@ -71,7 +71,3 @@ def new_fun(*args, **kwargs): ) return new_fun - - - - diff --git a/brainpy/math/activations.py b/brainpy/math/activations.py index f10998bed..d8cedd35d 100644 --- a/brainpy/math/activations.py +++ b/brainpy/math/activations.py @@ -30,8 +30,8 @@ import jax.numpy as jnp import jax.scipy import numpy as np - from brainstate.random import uniform + from .ndarray import Array __all__ = [ @@ -70,6 +70,7 @@ tanh = u.math.tanh + def get(activation): global_vars = globals() diff --git a/brainpy/math/defaults.py b/brainpy/math/defaults.py index 6ffb50146..067c8f895 100644 --- a/brainpy/math/defaults.py +++ b/brainpy/math/defaults.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import brainstate import jax.numpy as jnp from jax import config -import brainstate from .modes import NonBatchingMode from .scales import IdScaling diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py index 9654467f9..f9ff8fb9e 100644 --- a/brainpy/math/delayvars.py +++ b/brainpy/math/delayvars.py @@ -21,8 +21,8 @@ from jax import vmap from jax.lax import stop_gradient -from brainpy._errors import UnsupportedError from brainpy import check +from brainpy._errors import UnsupportedError from brainpy.check import is_float, is_integer, jit_error from .compat_numpy import broadcast_to, expand_dims, concatenate from .environment import get_dt, get_float diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py index bfc0e6e68..e9cb6ac9e 100644 --- a/brainpy/math/environment.py +++ b/brainpy/math/environment.py @@ -22,11 +22,11 @@ import warnings from typing import Any, Callable, TypeVar, cast +import brainstate.environ import jax from jax import config, numpy as jnp, devices from jax.lib import xla_bridge -import brainstate.environ from . import modes from . import scales from .defaults import defaults diff --git a/brainpy/math/jitconn/event_matvec.py b/brainpy/math/jitconn/event_matvec.py index 80eea34e0..b22fcb00b 100644 --- a/brainpy/math/jitconn/event_matvec.py +++ b/brainpy/math/jitconn/event_matvec.py @@ -20,8 +20,8 @@ import numpy as np from brainpy.math.jitconn.matvec import (mv_prob_homo, - mv_prob_uniform, - mv_prob_normal) + mv_prob_uniform, + mv_prob_normal) from brainpy.math.ndarray import Array as Array __all__ = [ diff --git a/brainpy/math/ndarray.py b/brainpy/math/ndarray.py index db6a33a66..7e486650f 100644 --- a/brainpy/math/ndarray.py +++ b/brainpy/math/ndarray.py @@ -232,7 +232,6 @@ def fill_(self, fill_value): return self - setattr(Array, "__array_priority__", 100) JaxArray = Array diff --git a/brainpy/math/object_transform/__init__.py b/brainpy/math/object_transform/__init__.py index b6fbdd531..eecc85fa6 100644 --- a/brainpy/math/object_transform/__init__.py +++ b/brainpy/math/object_transform/__init__.py @@ -39,3 +39,6 @@ from .jit import * from .naming import * from .variables import * + +if __name__ == '__main__': + ProgressBar diff --git a/brainpy/math/object_transform/_utils.py b/brainpy/math/object_transform/_utils.py index 191c72252..997f8b1bb 100644 --- a/brainpy/math/object_transform/_utils.py +++ b/brainpy/math/object_transform/_utils.py @@ -16,9 +16,9 @@ from functools import wraps from typing import Dict +import brainstate import jax.tree -import brainstate from .base import BrainPyObject, ArrayCollector __all__ = [ diff --git a/brainpy/math/object_transform/autograd.py b/brainpy/math/object_transform/autograd.py index 95ff403e1..a64b675c8 100644 --- a/brainpy/math/object_transform/autograd.py +++ b/brainpy/math/object_transform/autograd.py @@ -16,6 +16,7 @@ from typing import Union, Callable, Dict, Sequence, Optional import brainstate.transform + from ._utils import warp_to_no_state_input_output from .variables import Variable diff --git a/brainpy/math/object_transform/collectors.py b/brainpy/math/object_transform/collectors.py index 23a742a17..010c8b797 100644 --- a/brainpy/math/object_transform/collectors.py +++ b/brainpy/math/object_transform/collectors.py @@ -14,9 +14,9 @@ # ============================================================================== from typing import Sequence, Dict, Union +from brainstate._compatible_import import safe_zip from jax.tree_util import register_pytree_node -from brainstate._compatible_import import safe_zip from .variables import Variable __all__ = [ diff --git a/brainpy/math/object_transform/controls.py b/brainpy/math/object_transform/controls.py index c1c75b051..26c0f2aea 100644 --- a/brainpy/math/object_transform/controls.py +++ b/brainpy/math/object_transform/controls.py @@ -14,12 +14,12 @@ # limitations under the License. # ============================================================================== import numbers -from typing import Union, Sequence, Any, Dict, Callable, Optional +from typing import Union, Sequence, Any, Callable, Optional +import brainstate import jax import jax.numpy as jnp -import brainstate from brainpy.math.ndarray import Array from ._utils import warp_to_no_state_input_output @@ -35,37 +35,37 @@ def _convert_progress_bar_to_pbar( progress_bar: Union[bool, brainstate.transform.ProgressBar, int, None] ) -> Optional[brainstate.transform.ProgressBar]: - """Convert progress_bar parameter to brainstate pbar format. - - Parameters - ---------- - progress_bar : bool, ProgressBar, int, None - The progress_bar parameter value. - - Returns - ------- - pbar : ProgressBar or None - The converted ProgressBar instance or None. - - Raises - ------ - TypeError - If progress_bar is not a valid type. - """ - if progress_bar is False or progress_bar is None: - return None - elif progress_bar is True: - return brainstate.transform.ProgressBar() - elif isinstance(progress_bar, int): - # Support brainstate convention: int means freq parameter - return brainstate.transform.ProgressBar(freq=progress_bar) - elif isinstance(progress_bar, brainstate.transform.ProgressBar): - return progress_bar - else: - raise TypeError( - f"progress_bar must be bool, int, or ProgressBar instance, " - f"got {type(progress_bar).__name__}" - ) + """Convert progress_bar parameter to brainstate pbar format. + + Parameters + ---------- + progress_bar : bool, ProgressBar, int, None + The progress_bar parameter value. + + Returns + ------- + pbar : ProgressBar or None + The converted ProgressBar instance or None. + + Raises + ------ + TypeError + If progress_bar is not a valid type. + """ + if progress_bar is False or progress_bar is None: + return None + elif progress_bar is True: + return brainstate.transform.ProgressBar() + elif isinstance(progress_bar, int): + # Support brainstate convention: int means freq parameter + return brainstate.transform.ProgressBar(freq=progress_bar) + elif isinstance(progress_bar, brainstate.transform.ProgressBar): + return progress_bar + else: + raise TypeError( + f"progress_bar must be bool, int, or ProgressBar instance, " + f"got {type(progress_bar).__name__}" + ) def cond( diff --git a/brainpy/math/object_transform/jit.py b/brainpy/math/object_transform/jit.py index 0c60650fb..0ce02f40b 100644 --- a/brainpy/math/object_transform/jit.py +++ b/brainpy/math/object_transform/jit.py @@ -20,12 +20,12 @@ """ -from typing import Callable, Union, Optional, Sequence, Any, Iterable - -import jax.tree +from typing import Callable, Union, Sequence, Iterable import brainstate.transform +import jax.tree from brainstate.typing import Missing + from ._utils import warp_to_no_state_input_output __all__ = [ @@ -88,7 +88,6 @@ def jit( donate_argnums: Union[int, Sequence[int]] = (), inline: bool = False, keep_unused: bool = False, - abstracted_axes: Optional[Any] = None, # others **kwargs, ) -> Union[Callable, Callable[..., Callable]]: @@ -150,7 +149,6 @@ def jit( static_argnames=static_argnames, inline=inline, keep_unused=keep_unused, - abstracted_axes=abstracted_axes, **kwargs ) @@ -164,7 +162,6 @@ def cls_jit( static_argnames: Union[str, Iterable[str], None] = None, inline: bool = False, keep_unused: bool = False, - abstracted_axes: Optional[Any] = None, **kwargs ) -> Callable: """Just-in-time compile a function and then the jitted function as the bound method for a class. @@ -212,10 +209,9 @@ def cls_jit( return jit( func=func, static_argnums=static_argnums, - static_argnames=static_argnames, - inline=inline, - keep_unused=keep_unused, - abstracted_axes=abstracted_axes, + static_argnames=static_argnames, + inline=inline, + keep_unused=keep_unused, **kwargs ) diff --git a/brainpy/math/object_transform/tests/test_collector.py b/brainpy/math/object_transform/tests/test_collector.py index 1ba7da7fc..025060c12 100644 --- a/brainpy/math/object_transform/tests/test_collector.py +++ b/brainpy/math/object_transform/tests/test_collector.py @@ -283,4 +283,3 @@ def test_net_vars_2(): print() pprint(list(net.nodes(method='relative').keys())) # assert len(net.nodes(method='relative')) == 6 - diff --git a/brainpy/math/object_transform/tests/test_controls.py b/brainpy/math/object_transform/tests/test_controls.py index ba1f86ec8..66446d136 100644 --- a/brainpy/math/object_transform/tests/test_controls.py +++ b/brainpy/math/object_transform/tests/test_controls.py @@ -147,13 +147,13 @@ def test_for_loop_jit_false(self): """Test that jit=False disables JIT compilation""" a = bm.Variable(bm.zeros(1)) call_count = {'count': 0} - + def body(x): # This side effect should be visible when jit=False call_count['count'] += 1 a.value += x return a.value - + # Test with jit=False - should execute eagerly a.value = bm.zeros(1) call_count['count'] = 0 @@ -161,7 +161,7 @@ def body(x): # With jit=False, the function should be called 3 times self.assertEqual(call_count['count'], 3) self.assertTrue(bm.allclose(a.value, 3.)) - + def test_for_loop_jit_default(self): """Test that default behavior (jit=None) allows JIT compilation""" a = bm.Variable(bm.zeros(1)) @@ -191,9 +191,9 @@ def body(x): result = bm.for_loop(body, operands=bm.arange(0), jit=False) # Check that our specific warning was issued zero_length_warnings = [warning for warning in w - if "zero-length input" in str(warning.message)] + if "zero-length input" in str(warning.message)] self.assertGreaterEqual(len(zero_length_warnings), 1, - "Expected at least one zero-length input warning") + "Expected at least one zero-length input warning") # Variable should not have changed self.assertTrue(bm.allclose(a.value, 0.)) @@ -309,6 +309,7 @@ def f(carry, x): def test_scan_progress_bar_invalid_type(self): """Test that invalid progress_bar types raise TypeError""" + def f(carry, x): return carry + x, carry diff --git a/brainpy/math/sparse/utils.py b/brainpy/math/sparse/utils.py index 9b40d523f..ddab78677 100644 --- a/brainpy/math/sparse/utils.py +++ b/brainpy/math/sparse/utils.py @@ -62,5 +62,3 @@ def csr_to_coo( csr_to_dense = csr_todense - - diff --git a/brainpy/math/surrogate/_utils.py b/brainpy/math/surrogate/_utils.py index 7358e24f4..1fdbaad75 100644 --- a/brainpy/math/surrogate/_utils.py +++ b/brainpy/math/surrogate/_utils.py @@ -20,8 +20,8 @@ import jax -from brainpy._errors import UnsupportedError from brainpy import check +from brainpy._errors import UnsupportedError from brainpy.math.ndarray import Array as Array __all__ = [ diff --git a/brainpy/math/tests/test_numpy_ops.py b/brainpy/math/tests/test_numpy_ops.py index d06235004..8e7fa2b53 100644 --- a/brainpy/math/tests/test_numpy_ops.py +++ b/brainpy/math/tests/test_numpy_ops.py @@ -5398,7 +5398,7 @@ def testMgrid(self): np_mgrid = _indexer_with_default_outputs(np.mgrid) assertAllEqual = partial(self.assertAllClose, atol=0, rtol=0) assertAllEqual(np_mgrid[:4], bm.mgrid[:4]) - assertAllEqual(np_mgrid[:4, ], bm.mgrid[:4, ]) + assertAllEqual(np_mgrid[:4,], bm.mgrid[:4,]) assertAllEqual(np_mgrid[:4], jax.jit(lambda: bm.mgrid[:4])()) assertAllEqual(np_mgrid[:5, :5], bm.mgrid[:5, :5]) assertAllEqual(np_mgrid[:3, :2], bm.mgrid[:3, :2]) @@ -5453,7 +5453,7 @@ def assertListOfArraysEqual(xs, ys): self.assertArraysEqual(np_ogrid[:5], jax.jit(lambda: bm.ogrid[:5])()) self.assertArraysEqual(np_ogrid[1:7:2], bm.ogrid[1:7:2]) # List of arrays - assertListOfArraysEqual(np_ogrid[:5, ], bm.ogrid[:5, ]) + assertListOfArraysEqual(np_ogrid[:5,], bm.ogrid[:5,]) assertListOfArraysEqual(np_ogrid[0:5, 1:3], bm.ogrid[0:5, 1:3]) assertListOfArraysEqual(np_ogrid[1:3:2, 2:9:3], bm.ogrid[1:3:2, 2:9:3]) assertListOfArraysEqual(np_ogrid[:5, :9, :11], bm.ogrid[:5, :9, :11]) diff --git a/brainpy/mixin.py b/brainpy/mixin.py index 6df0f6e6b..d9afc05d5 100644 --- a/brainpy/mixin.py +++ b/brainpy/mixin.py @@ -17,9 +17,8 @@ from dataclasses import dataclass from typing import Union, Dict, Callable, Sequence, Optional, Any -import jax - import brainstate +import jax bm, delay_identifier, init_delay_by_return, DynamicalSystem = None, None, None, None diff --git a/brainpy/neurons.py b/brainpy/neurons.py index 6b1ce431a..4661f7452 100644 --- a/brainpy/neurons.py +++ b/brainpy/neurons.py @@ -17,42 +17,39 @@ This module has been deprecated since brainpy>=2.4.0. Use ``brainpy.dyn`` module instead. """ - +from brainpy.dyn.others import ( + InputGroup as InputGroup, + OutputGroup as OutputGroup, + SpikeTimeGroup as SpikeTimeGroup, + PoissonGroup as PoissonGroup, + Leaky as Leaky, + Integrator as Integrator, + OUProcess as OUProcess, +) from brainpy.dynold.neurons.biological_models import ( - HH as HH, - MorrisLecar as MorrisLecar, - PinskyRinzelModel as PinskyRinzelModel, - WangBuzsakiModel as WangBuzsakiModel, + HH as HH, + MorrisLecar as MorrisLecar, + PinskyRinzelModel as PinskyRinzelModel, + WangBuzsakiModel as WangBuzsakiModel, ) - from brainpy.dynold.neurons.fractional_models import ( - FractionalNeuron as FractionalNeuron, - FractionalFHR as FractionalFHR, - FractionalIzhikevich as FractionalIzhikevich, + FractionalNeuron as FractionalNeuron, + FractionalFHR as FractionalFHR, + FractionalIzhikevich as FractionalIzhikevich, ) - from brainpy.dynold.neurons.reduced_models import ( - LeakyIntegrator as LeakyIntegrator, - LIF as LIF, - ExpIF as ExpIF, - AdExIF as AdExIF, - QuaIF as QuaIF, - AdQuaIF as AdQuaIF, - GIF as GIF, - ALIFBellec2020 as ALIFBellec2020, - Izhikevich as Izhikevich, - HindmarshRose as HindmarshRose, - FHN as FHN, - LIF_SFA_Bellec2020, -) -from brainpy.dyn.others import ( - InputGroup as InputGroup, - OutputGroup as OutputGroup, - SpikeTimeGroup as SpikeTimeGroup, - PoissonGroup as PoissonGroup, - Leaky as Leaky, - Integrator as Integrator, - OUProcess as OUProcess, + LeakyIntegrator as LeakyIntegrator, + LIF as LIF, + ExpIF as ExpIF, + AdExIF as AdExIF, + QuaIF as QuaIF, + AdQuaIF as AdQuaIF, + GIF as GIF, + ALIFBellec2020 as ALIFBellec2020, + Izhikevich as Izhikevich, + HindmarshRose as HindmarshRose, + FHN as FHN, + LIF_SFA_Bellec2020, ) if __name__ == '__main__': @@ -85,4 +82,3 @@ Leaky Integrator OUProcess - diff --git a/brainpy/optim/optimizer.py b/brainpy/optim/optimizer.py index 584ef57e7..312b2376f 100644 --- a/brainpy/optim/optimizer.py +++ b/brainpy/optim/optimizer.py @@ -20,8 +20,8 @@ from jax.lax import cond import brainpy.math as bm -from brainpy._errors import MathError from brainpy import check +from brainpy._errors import MathError from brainpy.math.object_transform.base import BrainPyObject, ArrayCollector from .scheduler import make_schedule, Scheduler diff --git a/brainpy/optim/scheduler.py b/brainpy/optim/scheduler.py index 4a43c3e34..5941727a1 100644 --- a/brainpy/optim/scheduler.py +++ b/brainpy/optim/scheduler.py @@ -16,13 +16,13 @@ import warnings from typing import Sequence, Union +import brainstate import jax import jax.numpy as jnp import brainpy.math as bm -import brainstate -from brainpy._errors import MathError from brainpy import check +from brainpy._errors import MathError from brainpy.math.object_transform.base import BrainPyObject diff --git a/brainpy/rates.py b/brainpy/rates.py index 955c8b6d5..c0e24e1cc 100644 --- a/brainpy/rates.py +++ b/brainpy/rates.py @@ -19,7 +19,5 @@ from .dyn.rates import * - if __name__ == '__main__': FHN - diff --git a/brainpy/runners.py b/brainpy/runners.py index b0c98658d..8e107f51a 100644 --- a/brainpy/runners.py +++ b/brainpy/runners.py @@ -20,14 +20,14 @@ from collections.abc import Iterable from typing import Dict, Union, Sequence, Callable, Tuple, Optional, Any +import brainstate.environ import jax import jax.numpy as jnp import numpy as np from jax.tree_util import tree_map, tree_flatten -import brainstate.environ -from brainpy._errors import RunningError from brainpy import math as bm, tools +from brainpy._errors import RunningError from brainpy.context import share from brainpy.deprecations import _input_deprecate_msg from brainpy.dynsys import DynamicalSystem diff --git a/brainpy/running/runner.py b/brainpy/running/runner.py index 85afc11e9..ebd852825 100644 --- a/brainpy/running/runner.py +++ b/brainpy/running/runner.py @@ -19,8 +19,8 @@ import numpy as np -from brainpy._errors import MonitorError, RunningError from brainpy import math as bm, check +from brainpy._errors import MonitorError, RunningError from brainpy.math.object_transform.base import BrainPyObject from brainpy.tools import DotDict from . import constants as C diff --git a/brainpy/synouts.py b/brainpy/synouts.py index 5212630d0..6808a967c 100644 --- a/brainpy/synouts.py +++ b/brainpy/synouts.py @@ -25,6 +25,5 @@ MgBlock as MgBlock, ) - if __name__ == '__main__': - COBA, CUBA, MgBlock \ No newline at end of file + COBA, CUBA, MgBlock diff --git a/brainpy/test_main.py b/brainpy/test_main.py new file mode 100644 index 000000000..dab8e83f0 --- /dev/null +++ b/brainpy/test_main.py @@ -0,0 +1,18 @@ +# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +def test(): + import brainpy + print(brainpy.state) diff --git a/brainpy/tools/dicts.py b/brainpy/tools/dicts.py index 1d7c0d6f1..466298ef2 100644 --- a/brainpy/tools/dicts.py +++ b/brainpy/tools/dicts.py @@ -16,9 +16,8 @@ from typing import Union, Dict, Sequence import numpy as np -from jax.tree_util import register_pytree_node - from brainstate._compatible_import import safe_zip +from jax.tree_util import register_pytree_node __all__ = [ 'DotDict', diff --git a/brainpy/tools/others.py b/brainpy/tools/others.py index 728a6e6ce..59e2c7f83 100644 --- a/brainpy/tools/others.py +++ b/brainpy/tools/others.py @@ -132,4 +132,3 @@ def inner(*args, **kwargs): return inner return outer - diff --git a/brainpy/train/back_propagation.py b/brainpy/train/back_propagation.py index bb08bd13e..a820dc48e 100644 --- a/brainpy/train/back_propagation.py +++ b/brainpy/train/back_propagation.py @@ -17,6 +17,7 @@ from collections.abc import Iterable from typing import Union, Dict, Callable, Sequence, Optional +import brainstate.environ import jax.numpy as jnp import numpy as np from jax.tree_util import tree_map @@ -24,10 +25,9 @@ import brainpy.losses as losses import brainpy.math as bm -import brainstate.environ -from brainpy._errors import UnsupportedError, NoLongerSupportError from brainpy import optim from brainpy import tools +from brainpy._errors import UnsupportedError, NoLongerSupportError from brainpy.context import share from brainpy.dynsys import DynamicalSystem from brainpy.helpers import clear_input @@ -381,7 +381,8 @@ def fit( if self.loss_has_aux: test_epoch_metric['loss'].append(res[0]) if not isinstance(res[1], dict): - raise TypeError(f'Auxiliary data in loss function should be a dict. But we got {type(res)}') + raise TypeError( + f'Auxiliary data in loss function should be a dict. But we got {type(res)}') for k, v in res[1].items(): if k not in test_epoch_metric: test_epoch_metric[k] = [] diff --git a/brainpy/train/offline.py b/brainpy/train/offline.py index dfbee78c9..78863ec16 100644 --- a/brainpy/train/offline.py +++ b/brainpy/train/offline.py @@ -15,17 +15,17 @@ # ============================================================================== from typing import Dict, Sequence, Union, Callable, Any +import brainstate.environ import jax import numpy as np import tqdm.auto import brainpy.math as bm -import brainstate.environ -from brainpy.mixin import SupportOffline from brainpy import tools from brainpy.algorithms.offline import get, RidgeRegression, OfflineAlgorithm from brainpy.context import share from brainpy.dynsys import DynamicalSystem +from brainpy.mixin import SupportOffline from brainpy.runners import _call_fun_with_share from brainpy.types import ArrayType, Output from ._utils import format_ys diff --git a/brainpy/train/online.py b/brainpy/train/online.py index a0bacb099..e655a1355 100644 --- a/brainpy/train/online.py +++ b/brainpy/train/online.py @@ -16,18 +16,18 @@ import functools from typing import Dict, Sequence, Union, Callable +import brainstate.environ import jax import numpy as np import tqdm.auto from jax.tree_util import tree_map -import brainstate.environ -from brainpy.mixin import SupportOnline from brainpy import math as bm, tools from brainpy.algorithms.online import get, OnlineAlgorithm, RLS from brainpy.context import share from brainpy.dynsys import DynamicalSystem from brainpy.helpers import clear_input +from brainpy.mixin import SupportOnline from brainpy.runners import _call_fun_with_share from brainpy.types import ArrayType, Output from ._utils import format_ys diff --git a/brainpy/transform.py b/brainpy/transform.py index 00d834641..4422c6f53 100644 --- a/brainpy/transform.py +++ b/brainpy/transform.py @@ -287,7 +287,8 @@ def __call__( shared['i'] = jnp.arange(0, length[0]) + self.i0.value assert not self.no_state - xs = jax.tree.map(lambda x: x.value if isinstance(x, bm.Variable) else x, xs, is_leaf=lambda x: isinstance(x, bm.Variable)) + xs = jax.tree.map(lambda x: x.value if isinstance(x, bm.Variable) else x, xs, + is_leaf=lambda x: isinstance(x, bm.Variable)) results = bm.for_loop(functools.partial(self._run, self.shared_arg), (shared, xs), jit=self.jit) diff --git a/brainpy/visualization.py b/brainpy/visualization.py index 8a8f40bc2..d40a7c331 100644 --- a/brainpy/visualization.py +++ b/brainpy/visualization.py @@ -22,4 +22,3 @@ animate_2D = braintools.visualize.animate_2D remove_axis = braintools.visualize.remove_axis animator = braintools.visualize.animator - From 67889826d0bc86f7a94b85d2bf6064d620d8d245 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 25 Dec 2025 22:51:09 +0800 Subject: [PATCH 2/3] refactor: remove unused imports and clean up main execution block --- brainpy/__init__.py | 33 -------------------------------- brainpy/dyn/projections/utils.py | 1 + 2 files changed, 1 insertion(+), 33 deletions(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 2f84e37c3..04159b02d 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -150,36 +150,3 @@ import brainpy.state as state except: pass - -if __name__ == '__main__': - connect - initialize, # weight initialization - optim, # gradient descent optimizers - losses, # loss functions - measure, # methods for data analysis - inputs, # methods for generating input currents - encoding, # encoding schema - checkpoints, # checkpoints - check, # error checking - mixin, # mixin classes - algorithms, # online or offline training algorithms - check, tools, errors, math - BrainPyObject, - integrators, ode, sde, fde - Integrator, JointEq, IntegratorRunner, odeint, sdeint, fdeint - DynamicalSystem, DynSysGroup, Sequential, Dynamic, Projection - receive_update_input, receive_update_output, not_receive_update_input, not_receive_update_output - VarDelay - dnn, layers, dyn - NeuGroup, NeuGroupNS - share - reset_level, reset_state, save_state, load_state, clear_input - DSRunner, LoopOverTime, running - DSTrainer, BPTT, BPFF, OnlineTrainer, ForceTrainer, - OfflineTrainer, RidgeTrainer - analysis - visualize - train - channels, neurons, synapses, rates, synouts, synplast - Base - ArrayCollector, Collector, errors diff --git a/brainpy/dyn/projections/utils.py b/brainpy/dyn/projections/utils.py index 4ddfbdb57..dd58ff287 100644 --- a/brainpy/dyn/projections/utils.py +++ b/brainpy/dyn/projections/utils.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== from brainpy.mixin import ReturnInfo +import brainpy.math as bm def _get_return(return_info): From 07872deb2a926e34e042ff2abafefef886e3d765 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 25 Dec 2025 23:19:16 +0800 Subject: [PATCH 3/3] feat: add brainpy_state module and update dependencies --- brainpy/__init__.py | 8 ++++---- brainpy/state/README.md | 3 +++ brainpy/state/__init__.py | 20 ++++++++++++++++++++ pyproject.toml | 2 +- requirements.txt | 2 +- 5 files changed, 29 insertions(+), 6 deletions(-) create mode 100644 brainpy/state/README.md create mode 100644 brainpy/state/__init__.py diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 04159b02d..b41b72855 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -146,7 +146,7 @@ optimizers = optim -try: - import brainpy.state as state -except: - pass + +# New package +from brainpy import state + diff --git a/brainpy/state/README.md b/brainpy/state/README.md new file mode 100644 index 000000000..75548901c --- /dev/null +++ b/brainpy/state/README.md @@ -0,0 +1,3 @@ +# ``brainpy.state`` README + +This module is being maintained by [brainpy_state](https://github.com/chaobrain/brainpy.state). diff --git a/brainpy/state/__init__.py b/brainpy/state/__init__.py new file mode 100644 index 000000000..4f9491a40 --- /dev/null +++ b/brainpy/state/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from brainpy_state import * +from brainpy_state import __all__ + + + diff --git a/pyproject.toml b/pyproject.toml index 401735954..f10a14666 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dependencies = [ "brainunit", "brainevent>=0.0.4", "braintools>=0.0.9", - 'brainpy-state', + 'brainpy_state>=0.0.2', ] dynamic = ['version'] diff --git a/requirements.txt b/requirements.txt index 8d11828e9..726f19bb3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,6 @@ brainunit brainevent>=0.0.4 braintools>=0.1.0 brainstate>=0.2.7 -brainpy-state +brainpy_state>=0.0.2 jax tqdm