From 0233d8c1f444a19dd6470be76c0aa50c5e794dfc Mon Sep 17 00:00:00 2001 From: Adam Date: Mon, 16 Jun 2025 18:22:49 +0100 Subject: [PATCH 01/11] feat: add `target_strain` and `target_strain_rate` args to `LoadStep.unaxial` for convenience --- matflow/param_classes/load.py | 44 +++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/matflow/param_classes/load.py b/matflow/param_classes/load.py index a7a95ac8..06f46e98 100644 --- a/matflow/param_classes/load.py +++ b/matflow/param_classes/load.py @@ -1,6 +1,7 @@ """ Loadings to apply to a simulated sample. """ + from __future__ import annotations from collections.abc import Callable, Iterator import copy @@ -255,6 +256,8 @@ def uniaxial( total_time: float | int, num_increments: int, direction: str, + target_strain: float | None = None, + target_strain_rate: float | None = None, target_def_grad_rate: float | None = None, target_def_grad: float | None = None, dump_frequency: int = 1, @@ -272,9 +275,15 @@ def uniaxial( A single character, "x", "y" or "z", representing the loading direction. target_def_grad : float Target deformation gradient to achieve along the loading direction component. + target_strain: float + Target engineering strain to achieve along the loading direction. Specify at + most one of `target_strain` and `target_def_grad`. target_def_grad_rate : float Target deformation gradient rate to achieve along the loading direction component. + target_strain_rate: float + Target engineering strain rate to achieve along the loading direction. Specify + at most one of `target_strain_rate` and `target_def_grad_rate`. dump_frequency : int, optional By default, 1, meaning results are written out every increment. """ @@ -284,22 +293,37 @@ def uniaxial( "total_time": total_time, "num_increments": num_increments, "direction": direction, - "target_def_grad_rate": target_def_grad_rate, + "target_strain": target_strain, + "target_strain_rate": target_strain_rate, "target_def_grad": target_def_grad, + "target_def_grad_rate": target_def_grad_rate, "dump_frequency": dump_frequency, } # Validation: - msg = "Specify either `target_def_grad_rate` or `target_def_grad`." - if all([t is None for t in [target_def_grad_rate, target_def_grad]]): - raise ValueError(msg) - if all([t is not None for t in [target_def_grad_rate, target_def_grad]]): + msg = ( + "Specify either `target_strain`, `target_strain_rate`, " + "``target_def_grad` or target_def_grad_rate`." + ) + strain_arg = ( + target_strain, + target_strain_rate, + target_def_grad, + target_def_grad_rate, + ) + if sum(s is not None for s in strain_arg) != 1: raise ValueError(msg) - if target_def_grad_rate is not None: - def_grad_val = target_def_grad_rate + # convert strain (rate) to deformation gradient (rate) components: + t_dg = 1 + target_strain if target_strain is not None else target_def_grad + t_dg_rate = ( + target_strain_rate if target_strain_rate is not None else target_def_grad_rate + ) + + if t_dg_rate is not None: + def_grad_val = t_dg_rate else: - def_grad_val = target_def_grad + def_grad_val = t_dg try: loading_dir_idx = cls._DIR_IDX.index(direction) @@ -317,8 +341,8 @@ def uniaxial( dg_arr.mask[loading_dir_idx, loading_dir_idx] = False stress_arr.mask[loading_dir_idx, loading_dir_idx] = True - def_grad_aim = dg_arr if target_def_grad is not None else None - def_grad_rate = dg_arr if target_def_grad_rate is not None else None + def_grad_aim = dg_arr if t_dg is not None else None + def_grad_rate = dg_arr if t_dg_rate is not None else None obj = cls( direction=direction, From d09d26e483a0252b9ba8fbec0ad13ee5359414a5 Mon Sep 17 00:00:00 2001 From: Adam Date: Wed, 18 Jun 2025 11:20:30 +0100 Subject: [PATCH 02/11] refactor: ensure both `target_def_grad_(rate)` and `strain_(rate)` are populated in `LoadStep.uniaxial` --- matflow/param_classes/load.py | 44 +++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/matflow/param_classes/load.py b/matflow/param_classes/load.py index 06f46e98..8a86341a 100644 --- a/matflow/param_classes/load.py +++ b/matflow/param_classes/load.py @@ -214,6 +214,24 @@ def type(self) -> str: """More user-friendly access to method name.""" return self._method_name or self.__class__.__name__ + @property + def strain(self) -> float | None: + """ + For a limited subset of load step types (e.g. uniaxial), return the scalar target + strain. + """ + if self.type in ("uniaxial",): + return self.method_args["target_strain"] + + @property + def strain_rate(self) -> float | None: + """ + For a limited subset of load step types (e.g. uniaxial), return the scalar target + strain rate. + """ + if self.type in ("uniaxial",): + return self.method_args["target_strain_rate"] + def __repr__(self) -> str: type_str = f"type={self.type!r}, " if self.type else "" if self.direction: @@ -314,16 +332,22 @@ def uniaxial( if sum(s is not None for s in strain_arg) != 1: raise ValueError(msg) - # convert strain (rate) to deformation gradient (rate) components: - t_dg = 1 + target_strain if target_strain is not None else target_def_grad - t_dg_rate = ( - target_strain_rate if target_strain_rate is not None else target_def_grad_rate - ) + # convert strain (rate) to deformation gradient (rate) components, and ensure both + # strain(_rate) and def_grad(_rate) are populated: + if target_strain is not None: + target_def_grad = 1 + target_strain + elif target_def_grad is not None: + target_strain = target_def_grad - 1 - if t_dg_rate is not None: - def_grad_val = t_dg_rate + if target_strain_rate is not None: + target_def_grad_rate = target_strain_rate + elif target_def_grad_rate is not None: + target_strain_rate = target_def_grad_rate + + if target_def_grad_rate is not None: + def_grad_val = target_def_grad_rate else: - def_grad_val = t_dg + def_grad_val = target_def_grad try: loading_dir_idx = cls._DIR_IDX.index(direction) @@ -341,8 +365,8 @@ def uniaxial( dg_arr.mask[loading_dir_idx, loading_dir_idx] = False stress_arr.mask[loading_dir_idx, loading_dir_idx] = True - def_grad_aim = dg_arr if t_dg is not None else None - def_grad_rate = dg_arr if t_dg_rate is not None else None + def_grad_aim = dg_arr if target_def_grad is not None else None + def_grad_rate = dg_arr if target_def_grad_rate is not None else None obj = cls( direction=direction, From 1ff2760f0100d4181ed3678b1c456d155fb05ccb Mon Sep 17 00:00:00 2001 From: Adam Date: Wed, 18 Jun 2025 13:23:19 +0100 Subject: [PATCH 03/11] feat: add `BoundaryCondition` class in readiness for JAX-CPFEM load case support --- matflow/param_classes/boundary_conditions.py | 155 +++++++++++++++++++ matflow/param_classes/load.py | 23 +++ 2 files changed, 178 insertions(+) create mode 100644 matflow/param_classes/boundary_conditions.py diff --git a/matflow/param_classes/boundary_conditions.py b/matflow/param_classes/boundary_conditions.py new file mode 100644 index 00000000..6bf0c507 --- /dev/null +++ b/matflow/param_classes/boundary_conditions.py @@ -0,0 +1,155 @@ +""" +A boundary condition class to represent some types of load case in a way amenable to the +JAX-CPFEM crystal plasticity code. This is not a `ParameterValue` sub-class, but rather a +helper class. + +""" + +from collections.abc import Mapping, Sequence +from typing import Literal + +from typing_extensions import Final, Self + + +class BoundaryCondition: + """Simple boundary conditions container that can be used to represent a subset of + particular load cases, as corners or faces of a unit box, and the value that should be + applied within those regions. + + Parameters + ---------- + corners + Corners of the 3D box for which boundary condition should apply. Each corner + should be specified as a three-tuple of 0s or 1s. For example `(0, 0, 0)` + corresponds to the box origin, and `(0, 1, 1)` corresponds to the `x=0`, `y=1`, + `z=1` corner. Specify either `corners` or `faces`. + faces + A sequence of strings, where each string identifies the normal direction of a + specified box face, when viewed from the middle of the box. For example, the `+z` + direction corresponds to the `z=1` face, and the `-x` direction corresponds to the + `x=0` face. Specify either `corners` or `faces`. + value + A value for one or more of the "x", "y", and "z" components of the field. This can + be specified as a number or a string label that has some application-dependent + meaning. + + """ + + _DIRS: Final[tuple[str, ...]] = ("x", "y", "z") + + _JAX_CPFEM_BC_FUNC_CORNER_MAP = { + (0, 0, 0): "corner_0", + (1, 0, 0): "corner_1", + (1, 1, 0): "corner_2", + (0, 1, 0): "corner_3", + (0, 0, 1): "corner_4", + (1, 0, 1): "corner_5", + (1, 1, 1): "corner_6", + (0, 1, 1): "corner_7", + } + _JAX_CPFEM_BC_FUNC_FACE_MAP = { + "+x": "face_pos_x", + "-x": "face_neg_x", + "+y": "face_pos_y", + "-y": "face_neg_y", + "+z": "face_pos_z", + "-z": "face_neg_z", + } + _LOCATION_FUNC_CORNER_BODY_TEMPLATE = "return np.allclose(point, {corner}, atol=1e-5)" + _LOCATION_FUNC_FACE_BODY_TEMPLATE = ( + "return np.isclose(point[{comp_idx}], {value}, atol=1e-5)" + ) + + def __init__( + self, + value: Mapping[Literal["x", "y", "z"], float | str], + corners: ( + Sequence[Sequence[Literal[0, 1], Literal[0, 1], Literal[0, 1]]] | None + ) = None, + faces: Sequence[str] | None = None, + ): + if sum(i is not None for i in (corners, faces)) != 1: + raise ValueError("Specify exactly one of `corners` and `faces`.") + + self.corners = corners + self.faces = faces + self.value = value + + def __repr__(self): + region_name = "corners" if self.corners is not None else "faces" + region = self.corners if self.corners is not None else self.faces + return ( + f"{self.__class__.__name__}(" + f"{region_name}={region!r}, " + f"value={self.value!r}" + f")" + ) + + @classmethod + def uniaxial_tension(cls, direction: Literal["x", "y", "z"]) -> list[Self]: + """Generate a list of boundary conditions that correspond to uniaxial loading + along the specified direction.""" + non_axial_dirs = sorted(list(set(cls._DIRS).difference(direction))) + return [ + cls(corners=([0, 0, 0],), value={non_axial_dirs[0]: 0, non_axial_dirs[1]: 0}), + cls(faces=(f"-{direction}",), value={direction: 0}), + cls(faces=(f"+{direction}",), value={direction: "u(t)"}), + ] + + @staticmethod + def get_corner_location_func_str(corner) -> str: + """Generate Python code that defines a function that return True if the provided + 3D point is located at the specified corner, and False otherwise. + + Notes + ----- + The generated code is used as part of the problem definition script when running a + JAX-CPFEM simulation. + + """ + func_str = dedent( + """\ + def {func_name}(point): + {func_body} + + """ + ).format( + func_name=BoundaryCondition._JAX_CPFEM_BC_FUNC_CORNER_MAP[tuple(corner)], + func_body=BoundaryCondition._LOCATION_FUNC_CORNER_BODY_TEMPLATE.format( + corner=f"[{', '.join(str(i) for i in corner)}]" + ), + ) + return func_str + + @staticmethod + def get_face_location_func_str(face, domain_size) -> str: + """Generate Python code that defines a function that return True if the provided + 3D point is located within the specified plane (see the `BoundaryCondition.faces` + documentation for specification details), and False otherwise. + + Notes + ----- + The generated code is used as part of the problem definition script when running a + JAX-CPFEM simulation. + + """ + face = face.lower() + if len(face) == 1: + face = f"+{face}" + sign, dir = face + + comp_idx = BoundaryCondition._DIRS.index(dir) + func_str = dedent( + """\ + def {func_name}(point): + {func_body} + + """ + ).format( + func_name=BoundaryCondition._JAX_CPFEM_BC_FUNC_FACE_MAP[face], + func_body=BoundaryCondition._LOCATION_FUNC_FACE_BODY_TEMPLATE.format( + comp_idx=comp_idx, + value=domain_size[comp_idx] if sign == "+" else 0, + ), + ) + return func_str diff --git a/matflow/param_classes/load.py b/matflow/param_classes/load.py index 8a86341a..a80a224f 100644 --- a/matflow/param_classes/load.py +++ b/matflow/param_classes/load.py @@ -19,6 +19,7 @@ import matflow as mf from matflow.param_classes.utils import masked_array_from_list +from matflow.param_classes.boundary_conditions import BoundaryCondition logger = logging.getLogger(__name__) @@ -247,6 +248,16 @@ def __repr__(self) -> str: f")" ) + def to_dirichlet_BCs(self) -> list[BoundaryCondition]: + """For some particular types of load steps (e.g. uniaxial), we can transform them + into Dirichlet boundary conditions.""" + + if self.type == "uniaxial": + return BoundaryCondition.uniaxial_tension(self.direction) + raise NotImplementedError( + "Cannot express this load step in terms of boundary conditions." + ) + @classmethod def example_uniaxial(cls) -> Self: """ @@ -988,6 +999,18 @@ def create_damask_loading_plan(self) -> list[dict[str, Any]]: load_steps.append(dct) return load_steps + def to_dirichlet_BCs(self) -> list[BoundaryCondition]: + """For some particular types of load cases (e.g. uniaxial), we can transform them + into Dirichlet boundary conditions.""" + + if self.num_steps == 1: + return self.steps[0].to_dirichlet_BCs() + else: + raise NotImplementedError( + "It is not currently possible to express multi-step load cases in terms " + "of boundary conditions." + ) + @classmethod def uniaxial(cls, **kwargs) -> Self: """A single-step uniaxial load case. From a6332f8d0484476376955f1252462593e73a5784 Mon Sep 17 00:00:00 2001 From: Adam Date: Wed, 18 Jun 2025 13:47:32 +0100 Subject: [PATCH 04/11] feat: initial attempt at JAX-CPFEM support --- .../data/scripts/jax_cpfem/parse_stdout.py | 49 ++ matflow/data/scripts/jax_cpfem/write_mesh.py | 14 + .../jax_cpfem/write_model_py_script.py | 429 ++++++++++++++++++ .../jax_cpfem/write_problem_py_script.py | 416 +++++++++++++++++ .../scripts/jax_cpfem/write_slip_systems.py | 26 ++ .../workflows/simulate_JAX_CPFEM_copper.yaml | 150 ++++++ .../workflows/simulate_JAX_CPFEM_steel.yaml | 175 +++++++ 7 files changed, 1259 insertions(+) create mode 100644 matflow/data/scripts/jax_cpfem/parse_stdout.py create mode 100644 matflow/data/scripts/jax_cpfem/write_mesh.py create mode 100644 matflow/data/scripts/jax_cpfem/write_model_py_script.py create mode 100644 matflow/data/scripts/jax_cpfem/write_problem_py_script.py create mode 100644 matflow/data/scripts/jax_cpfem/write_slip_systems.py create mode 100644 matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml create mode 100644 matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml diff --git a/matflow/data/scripts/jax_cpfem/parse_stdout.py b/matflow/data/scripts/jax_cpfem/parse_stdout.py new file mode 100644 index 00000000..e4a4e85c --- /dev/null +++ b/matflow/data/scripts/jax_cpfem/parse_stdout.py @@ -0,0 +1,49 @@ +import os +import re +from pathlib import Path + +import numpy as np + +_PAT_FP = ( + r"Fp\s*=\s*\[\[\s*([-+]?\d*\.\d*\s+[-+]?\d*\.\d*\s+[-+]?\d*\.\d*)\s*\]\s*" + r"\[\s*([-+]?\d*\.\d*\s+[-+]?\d*\.\d*\s+[-+]?\d*\.\d*)\s*\]\s*" + r"\[\s*([-+]?\d*\.\d*\s+[-+]?\d*\.\d*\s+[-+]?\d*\.\d*)\s*\]\s*\]" +) +_PAT_key_data_list = R"{key}:\s*\[\s*((?:\d+\.\d+\s*)+)\]" +_PAT_STRESS_XX = _PAT_key_data_list.format(key="stress_xx") +_PAT_STRESS_YY = _PAT_key_data_list.format(key="stress_yy") +_PAT_STRESS_ZZ = _PAT_key_data_list.format(key="stress_zz") +_PAT_STRESS_VM = _PAT_key_data_list.format(key="von_mises_stress") + + +def parse_stdout(jax_cpfem_stdout: Path): + + encoding = "utf-16" if os.name == "nt" else "utf-8" # docker on windows nonsense + with jax_cpfem_stdout.open("rt", encoding=encoding) as fh: + contents = fh.read() + + fp = [] + for fp_res in re.findall(_PAT_FP, contents): + fp.append([[float(i) for i in fp_row.split()] for fp_row in fp_res]) + fp_arr = np.asarray(fp) + + stress_xx = np.array( + [float(i) for i in re.search(_PAT_STRESS_XX, contents).groups()[0].split()] + ) + stress_yy = np.array( + [float(i) for i in re.search(_PAT_STRESS_YY, contents).groups()[0].split()] + ) + stress_zz = np.array( + [float(i) for i in re.search(_PAT_STRESS_ZZ, contents).groups()[0].split()] + ) + stress_von_mises = np.array( + [float(i) for i in re.search(_PAT_STRESS_VM, contents).groups()[0].split()] + ) + + return { + "Fp": fp_arr, + "stress_xx": stress_xx, + "stress_yy": stress_yy, + "stress_zz": stress_zz, + "stress_von_mises": stress_von_mises, + } diff --git a/matflow/data/scripts/jax_cpfem/write_mesh.py b/matflow/data/scripts/jax_cpfem/write_mesh.py new file mode 100644 index 00000000..9a3497d6 --- /dev/null +++ b/matflow/data/scripts/jax_cpfem/write_mesh.py @@ -0,0 +1,14 @@ +import numpy as np + + +def write_mesh(path, volume_element): + np.savez( + path, + num_cells=volume_element["grid_size"][:], # i.e. Nx, Ny, Nz + domain=np.asarray(volume_element["size"]), # i.e. domain_x, domain_y, domain_z + cell_grain_indices=volume_element["element_material_idx"][:].reshape( + -1 + ), # TODO: check order! + quaternions=volume_element["orientations"]["quaternions"][:], + grain_orientation_indices=volume_element["constituent_orientation_idx"][:], + ) diff --git a/matflow/data/scripts/jax_cpfem/write_model_py_script.py b/matflow/data/scripts/jax_cpfem/write_model_py_script.py new file mode 100644 index 00000000..0857dd78 --- /dev/null +++ b/matflow/data/scripts/jax_cpfem/write_model_py_script.py @@ -0,0 +1,429 @@ +from textwrap import dedent + + +def write_model_py_script(path, material_parameters, numerics): + + TEMPLATE = dedent( + """\ + import numpy as onp + import jax + import jax.numpy as np + import jax.flatten_util + import os + import sys + from functools import partial + + from jax_fem.problem import Problem + from jax import config + + config.update("jax_enable_x64", True) + + onp.set_printoptions(threshold=sys.maxsize, linewidth=1000, suppress=True, precision=10) + + + crt_dir = os.path.dirname(__file__) + + + def rotate_tensor_rank_4(R, T): + R0 = R[:, :, None, None, None, None, None, None] + R1 = R[None, None, :, :, None, None, None, None] + R2 = R[None, None, None, None, :, :, None, None] + R3 = R[None, None, None, None, None, None, :, :] + return np.sum( + R0 * R1 * R2 * R3 * T[None, :, None, :, None, :, None, :], axis=(1, 3, 5, 7) + ) + + + def rotate_tensor_rank_2(R, T): + R0 = R[:, :, None, None] + R1 = R[None, None, :, :] + return np.sum(R0 * R1 * T[None, :, None, :], axis=(1, 3)) + + + rotate_tensor_rank_2_vmap = jax.vmap(rotate_tensor_rank_2, in_axes=(None, 0)) + + + def get_rot_mat(q): + ''' + Transformation from quaternion to the corresponding rotation matrix. + Reference: https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation + ''' + ## Hu: return rotation matrix -- (3,3) + return np.array( + [ + [ + q[0] * q[0] + q[1] * q[1] - q[2] * q[2] - q[3] * q[3], + 2 * q[1] * q[2] - 2 * q[0] * q[3], + 2 * q[1] * q[3] + 2 * q[0] * q[2], + ], + [ + 2 * q[1] * q[2] + 2 * q[0] * q[3], + q[0] * q[0] - q[1] * q[1] + q[2] * q[2] - q[3] * q[3], + 2 * q[2] * q[3] - 2 * q[0] * q[1], + ], + [ + 2 * q[1] * q[3] - 2 * q[0] * q[2], + 2 * q[2] * q[3] + 2 * q[0] * q[1], + q[0] * q[0] - q[1] * q[1] - q[2] * q[2] + q[3] * q[3], + ], + ] + ) + + + get_rot_mat_vmap = jax.vmap(get_rot_mat) + + + class CrystalPlasticity(Problem): + def custom_init(self, quat, cell_ori_inds): + r = {r} + self.gss_initial = {gss_initial} + + input_slip_sys = onp.loadtxt("input_slip_sys.txt") + num_slip_sys = len(input_slip_sys) + + slip_directions = input_slip_sys[:, self.dim :] + slip_directions = ( + slip_directions / onp.linalg.norm(slip_directions, axis=1)[:, None] + ) + slip_normals = input_slip_sys[:, : self.dim] + slip_normals = slip_normals / onp.linalg.norm(slip_normals, axis=1)[:, None] + + self.Schmid_tensors = jax.vmap(np.outer)(slip_directions, slip_normals) + + self.q = r * onp.ones((num_slip_sys, num_slip_sys)) + + num_directions_per_normal = 3 + for i in range(num_slip_sys): + for j in range(num_directions_per_normal): + self.q[ + i, i // num_directions_per_normal * num_directions_per_normal + j + ] = 1.0 + + rot_mats = onp.array(get_rot_mat_vmap(quat)[cell_ori_inds]) + + ### Note: for CPFEM, self.num_vars=1, which means fes only have 1 Finite Element object + ### Multi-physics CPFEM is under study + ## Hu: Fp - plastic deformation gradient + Fp_inv_gp = onp.repeat( + onp.repeat( + onp.eye(self.dim)[None, None, :, :], len(self.fes[0].cells), axis=0 + ), + self.fes[0].num_quads, + axis=1, + ) + ## Hu: slip resistance + slip_resistance_gp = self.gss_initial * onp.ones( + (len(self.fes[0].cells), self.fes[0].num_quads, num_slip_sys) + ) + ## Hu: slip rate + slip_gp = onp.zeros_like(slip_resistance_gp) + ## Hu: rotation matrix + rot_mats_gp = onp.repeat(rot_mats[:, None, :, :], self.fes[0].num_quads, axis=1) + ## Hu: elastic modulus + self.C = onp.zeros((self.dim, self.dim, self.dim, self.dim)) + + ## Hu: unit: MPa + C11 = {C11} + C12 = {C12} + C44 = {C44} + + self.C[0, 0, 0, 0] = C11 + self.C[1, 1, 1, 1] = C11 + self.C[2, 2, 2, 2] = C11 + + self.C[0, 0, 1, 1] = C12 + self.C[1, 1, 0, 0] = C12 + + self.C[0, 0, 2, 2] = C12 + self.C[2, 2, 0, 0] = C12 + + self.C[1, 1, 2, 2] = C12 + self.C[2, 2, 1, 1] = C12 + + self.C[1, 2, 1, 2] = C44 + self.C[1, 2, 2, 1] = C44 + self.C[2, 1, 1, 2] = C44 + self.C[2, 1, 2, 1] = C44 + + self.C[2, 0, 2, 0] = C44 + self.C[2, 0, 0, 2] = C44 + self.C[0, 2, 2, 0] = C44 + self.C[0, 2, 0, 2] = C44 + + self.C[0, 1, 0, 1] = C44 + self.C[0, 1, 1, 0] = C44 + self.C[1, 0, 0, 1] = C44 + self.C[1, 0, 1, 0] = C44 + + ## Hu: internal variables + self.internal_vars = [Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp] + + def get_tensor_map(self): + tensor_map, _ = self.get_maps() + return tensor_map + + def get_maps(self): + ## Hu: initial hardening, unit: MPa + h = {h} + ## Hu: saturation strength, unit: MPa + t_sat = {t_sat} + ## Hu: 'a' in Kalidini model + gss_a = {gss_a} + ## Hu: reference strain rate + ao = {ao} + ## Hu: rate sensitivity exponent + xm = 1.0 / {xm_d} + + def get_partial_tensor_map(Fp_inv_old, slip_resistance_old, slip_old, rot_mat): + _, unflatten_fn = jax.flatten_util.ravel_pytree(Fp_inv_old) + _, unflatten_fn_params = jax.flatten_util.ravel_pytree( + [Fp_inv_old, Fp_inv_old, slip_resistance_old, slip_old, rot_mat] + ) + + def first_PK_stress(u_grad): + x, _ = jax.flatten_util.ravel_pytree( + [u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat] + ) + y = newton_solver(x) + S = unflatten_fn(y) + _, _, _, Fe, F = helper( + u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat, S + ) + sigma = 1.0 / np.linalg.det(Fe) * Fe @ S @ Fe.T + P = np.linalg.det(F) * sigma @ np.linalg.inv(F).T + return P + + def update_int_vars(u_grad): + x, _ = jax.flatten_util.ravel_pytree( + [u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat] + ) + y = newton_solver(x) + S = unflatten_fn(y) + Fp_inv_new, slip_resistance_new, slip_new, Fe, F = helper( + u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat, S + ) + return Fp_inv_new, slip_resistance_new, slip_new, rot_mat + + ## Hu: S-based formulation + def helper(u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat, S): + tau = np.sum( + S[None, :, :] + * rotate_tensor_rank_2_vmap(rot_mat, self.Schmid_tensors), + axis=(1, 2), + ) + gamma_inc = ( + ao + * self.dt + * np.absolute(tau / slip_resistance_old) ** (1.0 / xm) + * np.sign(tau) + ) + + tmp = ( + h + * np.absolute(gamma_inc) + * np.absolute(1 - slip_resistance_old / t_sat) ** gss_a + * np.sign(1 - slip_resistance_old / t_sat) + ) + g_inc = (self.q @ tmp[:, None]).reshape(-1) + + # tmp = h*np.absolute(gamma_inc) / np.cosh(h*np.sum(slip_old)/(t_sat - self.gss_initial))**2 + # g_inc = (self.q @ tmp[:, None]).reshape(-1) + + slip_resistance_new = slip_resistance_old + g_inc + slip_new = slip_old + gamma_inc + F = u_grad + np.eye(self.dim) + L_plastic_inc = np.sum( + gamma_inc[:, None, None] + * rotate_tensor_rank_2_vmap(rot_mat, self.Schmid_tensors), + axis=0, + ) + Fp_inv_new = Fp_inv_old @ (np.eye(self.dim) - L_plastic_inc) + Fe = F @ Fp_inv_new + return Fp_inv_new, slip_resistance_new, slip_new, Fe, F + + ## Hu: Calculate the residual function + def implicit_residual(x, y): + u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat = ( + unflatten_fn_params(x) + ) + S = unflatten_fn(y) + _, _, _, Fe, _ = helper( + u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat, S + ) + S_ = np.sum( + rotate_tensor_rank_4(rot_mat, self.C) + * 1.0 + / 2.0 + * (Fe.T @ Fe - np.eye(self.dim))[None, None, :, :], + axis=(2, 3), + ) + res, _ = jax.flatten_util.ravel_pytree(S - S_) + return res + + ## Hu: inner Newton's method + @jax.custom_jvp + def newton_solver(x): + # Critical change: The following line causes JAX (version 0.4.13) tracer error + # y0 = np.zeros_like(Fp_inv_old.reshape(-1)) + y0 = np.zeros(self.dim * self.dim) + + step = 0 + res_vec = implicit_residual(x, y0) + tol = 1e-8 + + def cond_fun(state): + step, res_vec, y = state + ## Hu: In MOOSE: while (rnorm > _rtol * rnorm0 && rnorm > _abs_tol && iteration < _maxiter) + return np.linalg.norm(res_vec) > tol + + def body_fun(state): + # Line search with decaying relaxation parameter (the "cut half" method). + # This is necessary since vanilla Newton's method may sometimes not converge. + # MOOSE has an implementation in C++, see the following link + # https://github.com/idaholab/moose/blob/next/modules/tensor_mechanics/src/materials/ + # crystal_plasticity/ComputeMultipleCrystalPlasticityStress.C#L634 + step, res_vec, y = state + ## Hu: Input: S, Output: res + f_partial = lambda y: implicit_residual(x, y) + jac = jax.jacfwd(f_partial)(y) + y_inc = np.linalg.solve(jac, -res_vec) + + relax_param_ini = 1.0 + sub_step_ini = 0 + max_sub_step = {newton_max_sub_step} + + def sub_cond_fun(state): + _, crt_res_vec, sub_step = state + return np.logical_and( + np.linalg.norm(crt_res_vec) >= np.linalg.norm(res_vec), + sub_step < max_sub_step, + ) + + def sub_body_fun(state): + relax_param, res_vec, sub_step = state + res_vec = f_partial(y + relax_param * y_inc) + return 0.5 * relax_param, res_vec, sub_step + 1 + + relax_param_f, res_vec_f, _ = jax.lax.while_loop( + sub_cond_fun, + sub_body_fun, + (relax_param_ini, res_vec, sub_step_ini), + ) + step_update = step + 1 + + return step_update, res_vec_f, y + 2.0 * relax_param_f * y_inc + + step_f, res_vec_f, y_f = jax.lax.while_loop( + cond_fun, body_fun, (step, res_vec, y0) + ) + + return y_f + + @newton_solver.defjvp + def f_jvp(primals, tangents): + (x,) = primals + (v,) = tangents + y = newton_solver(x) + jac_x = jax.jacfwd(implicit_residual, argnums=0)(x, y) + jac_y = jax.jacfwd(implicit_residual, argnums=1)(x, y) + jvp_result = np.linalg.solve(jac_y, -(jac_x @ v[:, None]).reshape(-1)) + return y, jvp_result + + return first_PK_stress, update_int_vars + + def tensor_map(u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat): + first_PK_stress, _ = get_partial_tensor_map( + Fp_inv_old, slip_resistance_old, slip_old, rot_mat + ) + return first_PK_stress(u_grad) + + def update_int_vars_map( + u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat + ): + _, update_int_vars = get_partial_tensor_map( + Fp_inv_old, slip_resistance_old, slip_old, rot_mat + ) + return update_int_vars(u_grad) + + return tensor_map, update_int_vars_map + + def update_int_vars_gp(self, sol, params): + _, update_int_vars_map = self.get_maps() + vmap_update_int_vars_map = jax.jit(jax.vmap(jax.vmap(update_int_vars_map))) + # (num_cells, 1, num_nodes, vec, 1) * (num_cells, num_quads, num_nodes, 1, dim) -> (num_cells, num_quads, num_nodes, vec, dim) + u_grads = ( + np.take(sol, self.fes[0].cells, axis=0)[:, None, :, :, None] + * self.fes[0].shape_grads[:, :, :, None, :] + ) + u_grads = np.sum(u_grads, axis=2) # (num_cells, num_quads, vec, dim) + Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp = vmap_update_int_vars_map( + u_grads, *params + ) + # TODO + return [Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp] + + def set_params(self, params): + self.internal_vars = params + + def inspect_interval_vars(self, params): + '''For post-processing only''' + Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp = params + F_p = np.linalg.inv(Fp_inv_gp[0, 0]) + print(f"Fp = \\n{{F_p}}") + slip_resistance_0 = slip_resistance_gp[0, 0, 0] + print( + f"slip_resistance index 0 = {{slip_resistance_0}}, max slip_resistance = {{np.max(slip_resistance_gp)}}" + ) + return F_p[2, 2], slip_resistance_0, slip_gp[0, 0, 0] + + def compute_avg_stress(self, sol, params): + '''For post-processing only''' + # (num_cells, 1, num_nodes, vec, 1) * (num_cells, num_quads, num_nodes, 1, dim) -> (num_cells, num_quads, num_nodes, vec, dim) + u_grads = ( + np.take(sol, self.fes[0].cells, axis=0)[:, None, :, :, None] + * self.fes[0].shape_grads[:, :, :, None, :] + ) + u_grads = np.sum(u_grads, axis=2) # (num_cells, num_quads, vec, dim) + + partial_tensor_map, _ = self.get_maps() + vmap_partial_tensor_map = jax.jit(jax.vmap(jax.vmap(partial_tensor_map))) + P = vmap_partial_tensor_map(u_grads, *params) + + def P_to_sigma(P, F): + return 1.0 / np.linalg.det(F) * P @ F.T + + vvmap_P_to_sigma = jax.vmap(jax.vmap(P_to_sigma)) + F = u_grads + np.eye(self.dim)[None, None, :, :] + sigma = vvmap_P_to_sigma(P, F) + + sigma_cell_data = ( + np.sum(sigma * self.fes[0].JxW[:, :, None, None], 1) + / np.sum(self.fes[0].JxW, axis=1)[:, None, None] + ) + + # num_cells*num_quads, vec, dim) * (num_cells*num_quads, 1, 1) + avg_P = np.sum( + P.reshape(-1, self.fes[0].vec, self.dim) + * self.fes[0].JxW.reshape(-1)[:, None, None], + 0, + ) / np.sum(self.fes[0].JxW) + return sigma_cell_data + """ + ) + + with path.open("wt") as fh: + fh.write( + TEMPLATE.format( + C11=material_parameters["C11"], + C12=material_parameters["C12"], + C44=material_parameters["C44"], + r=material_parameters["r"], + gss_initial=material_parameters["gss_initial"], + h=material_parameters["h"], + t_sat=material_parameters["t_sat"], + gss_a=material_parameters["gss_a"], + ao=material_parameters["ao"], + xm_d=material_parameters["xm_d"], + newton_max_sub_step=numerics["newton_max_sub_step"], + ) + ) diff --git a/matflow/data/scripts/jax_cpfem/write_problem_py_script.py b/matflow/data/scripts/jax_cpfem/write_problem_py_script.py new file mode 100644 index 00000000..5df5682e --- /dev/null +++ b/matflow/data/scripts/jax_cpfem/write_problem_py_script.py @@ -0,0 +1,416 @@ +from textwrap import dedent, indent +from collections.abc import Sequence + +_DIR_LOOKUP = { + "x": "Lx", + "y": "Ly", + "z": "Lz", +} + + +def write_problem_py_script(path, loading, solver_options): + + TEMPLATE = dedent( + """\ + import jax + import jax.numpy as np + import numpy as onp + import os + import time + + from jax_fem.solver import solver + from jax_fem.generate_mesh import box_mesh, Mesh + from jax_fem.utils import save_sol + + + from model import CrystalPlasticity + + os.environ["CUDA_VISIBLE_DEVICES"] = "2" + + case_name = "{case_name}" + + data_dir = os.path.join(os.path.dirname(__file__), "data") + vtk_dir = os.path.join(data_dir, f"vtk/{{case_name}}") + + + def load_mesh_file(): + mesh_data = onp.load("mesh.npz") + meshio_obj = box_mesh(*mesh_data["num_cells"], *mesh_data["domain"]) + mesh_obj = Mesh(meshio_obj.points, meshio_obj.cells_dict["hexahedron"], ele_type="HEX8") + return {{ + "mesh_obj": mesh_obj, + "cell_grain_indices": mesh_data["cell_grain_indices"], + "quaternions": mesh_data["quaternions"], + "grain_orientation_indices": mesh_data["grain_orientation_indices"], + }} + + + def problem(): + print(jax.lib.xla_bridge.get_backend().platform) + + ele_type = "HEX8" + + mesh_data = load_mesh_file() + mesh = mesh_data["mesh_obj"] + cell_grain_inds = mesh_data["cell_grain_indices"] + grain_oris_inds = mesh_data["grain_orientation_indices"] + cell_ori_inds = grain_oris_inds[cell_grain_inds] + quat = mesh_data["quaternions"] + + print(f"{{quat=!r}}") + print("cell_grain_inds", cell_grain_inds) + print("grain_oris_inds", grain_oris_inds) + print("cell_ori_inds", cell_ori_inds) + print("No. of total mesh points:", mesh.points.shape) + + Lx = np.max(mesh.points[:, 0]) + Ly = np.max(mesh.points[:, 1]) + Lz = np.max(mesh.points[:, 2]) + print(f"Domain size: {{Lx}}, {{Ly}}, {{Lz}}") + + displacements = np.linspace(0, {strain}*{direction}, {num_increments} + 1) + ts = np.linspace(0, {total_time}, {num_increments} + 1) + + ## Hu: Define index of points and faces + def corner(point): + flag_x = np.isclose(point[0], 0.0, atol=1e-5) + flag_y = np.isclose(point[1], 0.0, atol=1e-5) + flag_z = np.isclose(point[2], Lz, atol=1e-5) + return np.logical_and(np.logical_and(flag_x, flag_y), flag_z) + + def corner2(point): + flag_x = np.isclose(point[0], 0.0, atol=1e-5) + flag_y = np.isclose(point[1], 0.0, atol=1e-5) + flag_z = np.isclose(point[2], 0.0, atol=1e-5) + return np.logical_and(np.logical_and(flag_x, flag_y), flag_z) + + def corner3(point): + flag_x = np.isclose(point[0], Lx, atol=1e-5) + flag_y = np.isclose(point[1], 0.0, atol=1e-5) + flag_z = np.isclose(point[2], 0.0, atol=1e-5) + return np.logical_and(np.logical_and(flag_x, flag_y), flag_z) + + def corner4(point): + flag_x = np.isclose(point[0], Lx, atol=1e-5) + flag_y = np.isclose(point[1], 0.0, atol=1e-5) + flag_z = np.isclose(point[2], Lz, atol=1e-5) + return np.logical_and(np.logical_and(flag_x, flag_y), flag_z) + + def left(point): + return np.isclose(point[0], 0.0, atol=1e-5) + + def right(point): + return np.isclose(point[0], Lx, atol=1e-5) + + def front(point): + return np.isclose(point[1], 0.0, atol=1e-5) + + def back(point): + return np.isclose(point[1], Ly, atol=1e-5) + + def bottom(point): + return np.isclose(point[2], 0.0, atol=1e-5) + + def top(point): + return np.isclose(point[2], Lz, atol=1e-5) + + ## Hu: Define dirichlet B.C. + def zero_dirichlet_val(point): + return 0.0 + + def get_dirichlet_top(disp): + def val_fn(point): + return disp + + return val_fn + + {dirichlet_BCs} + + ## Hu: Define CPFEM problem on top of JAX-FEM + ## Xue, Tianju, et al. Computer Physics Communications 291 (2023): 108802. + problem = CrystalPlasticity( + mesh, + vec=3, + dim=3, + ele_type=ele_type, + dirichlet_bc_info=dirichlet_bc_info, + additional_info=(quat, cell_ori_inds), + ) + + sol_list = [np.zeros((problem.fes[0].num_total_nodes, problem.fes[0].vec))] + ## Hu: self.internal_vars = [Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp] + params = problem.internal_vars + + results_to_save = [] + stress_plot = np.array([]) + stress_xx_plot = np.array([]) + stress_yy_plot = np.array([]) + stress_zz_plot = np.array([]) + von_mises_stress_plot = np.array([]) + + for i in range(len(ts) - 1): + problem.dt = ts[i + 1] - ts[i] + print( + f"\\nStep {{i + 1}} in {{len(ts) - 1}}, disp = {{displacements[i + 1]}}, dt = {{problem.dt}}" + ) + + ## Hu: Reset Dirichlet boundary conditions. + ## Hu: Useful when a time-dependent problem is solved, and at each iteration the boundary condition needs to be updated. + dirichlet_bc_info[-1][{dirichlet_ut_idx}] = get_dirichlet_top(displacements[i + 1]) + problem.fes[0].update_Dirichlet_boundary_conditions(dirichlet_bc_info) + + ## Hu: Set up internal variables of previous step for inner Newton's method + ## self.internal_vars = [Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp] + problem.set_params(params) + + ## Hu: JAX-FEM's solver for outer Newton's method + ## solver(problem, solver_options={{}}) + ## Examples: + ## (1) solver_options = {{'jax_solver': {{}}}} + ## (2) solver_options = {{'umfpack_solver': {{}}}} + ## (3) solver_options = {{'petsc_solver': {{'ksp_type': 'bcgsl', 'pc_type': 'jacobi'}}, 'initial_guess': some_guess}} + sol_list = solver( + problem, solver_options={{"{solver_name}": {solver_options}, "initial_guess": sol_list}} + ) + + ## Hu: Post-processing for aacroscopic Cauchy stress of each cell + print(f"Computing stress...") + sigma_cell_data = problem.compute_avg_stress(sol_list[0], params)[:, :, :] + sigma_cell_xx = sigma_cell_data[:, 0, 0] + sigma_cell_yy = sigma_cell_data[:, 1, 1] + sigma_cell_zz = sigma_cell_data[:, 2, 2] + sigma_cell_xy = sigma_cell_data[:, 0, 1] + sigma_cell_xz = sigma_cell_data[:, 0, 2] + sigma_cell_yz = sigma_cell_data[:, 1, 2] + sigma_cell_von_mises_stress = ( + 0.5 + * ( + (sigma_cell_xx - sigma_cell_yy) ** 2.0 + + (sigma_cell_yy - sigma_cell_zz) ** 2.0 + + (sigma_cell_zz - sigma_cell_xx) ** 2.0 + ) + + +3.0 * (sigma_cell_xy**2.0 + sigma_cell_yz**2.0 + sigma_cell_xz**2.0) + ) ** 0.5 + + stress_xx_plot = np.append(stress_xx_plot, np.mean(sigma_cell_xx)) + stress_yy_plot = np.append(stress_yy_plot, np.mean(sigma_cell_yy)) + stress_zz_plot = np.append(stress_zz_plot, np.mean(sigma_cell_zz)) + von_mises_stress_plot = np.append( + von_mises_stress_plot, np.mean(sigma_cell_von_mises_stress) + ) + print( + f"Average Cauchy stress: stress_xx = {{stress_xx_plot[-1]}}, stress_yy = {{stress_yy_plot[-1]}}, stress_zz = {{stress_zz_plot[-1]}}, \\ + vM_stress = {{von_mises_stress_plot[-1]}}, max stress = {{np.max(sigma_cell_data)}}" + ) + + ## Hu: Update internal variables + ## self.internal_vars = [Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp] + print(f"Updating int vars...") + params = problem.update_int_vars_gp(sol_list[0], params) + F_p_zz, slip_resistance_0, slip_0 = problem.inspect_interval_vars(params) + + ## Hu: Post-processing for visualization + vtk_path = os.path.join(vtk_dir, f"u_{{i:03d}}.vtu") + save_sol( + problem.fes[0], + sol_list[0], + vtk_path, + cell_infos=[ + ("cell_ori_inds", cell_ori_inds), + ("sigma_xx", sigma_cell_xx), + ("sigma_yy", sigma_cell_yy), + ("sigma_zz", sigma_cell_zz), + ("von_Mises_stress", sigma_cell_von_mises_stress), + ], + ) + + print("*************") + print("grain_oris_inds:\\n", grain_oris_inds) + print("stress_xx:\\n", onp.array(stress_xx_plot, order="F", dtype=onp.float64)) + print("stress_yy:\\n", onp.array(stress_yy_plot, order="F", dtype=onp.float64)) + print("stress_zz:\\n", onp.array(stress_zz_plot, order="F", dtype=onp.float64)) + print( + "von_mises_stress:\\n", + onp.array(von_mises_stress_plot, order="F", dtype=onp.float64), + ) + print("*************") + + + if __name__ == "__main__": + start_time = time.time() + problem() + end_time = time.time() + run_time = end_time - start_time + print("Simulation time:", run_time) + """ + ) + + DISP_COMP_LOOKUP = {"u_x": 0, "u_y": 1, "u_z": 2} + + u_t_idx = None # the of the BC where we specify the required displacement at time t + location_funcs = [] + vector_comps = [] + value_funcs = [] + for dbc in loading["dirichlet_BCs"]: + for disp_comp, value in dbc["values"].items(): + + if value == 0: + # a function that takes a point and returns zero: + value = "zero_dirichlet_val" + elif value.lower() == "u(t)": + # a function that takes a point and returns the required displacement at + # time t: + value = "get_dirichlet_top(displacements[0])" + u_t_idx = len(location_funcs) + + location_funcs.append(dbc["points"]) + vector_comps.append(DISP_COMP_LOOKUP[disp_comp]) + value_funcs.append(value) + + DBCs_TEMPLATE = dedent( + """\ + dirichlet_bc_info = [ + {location_functions_str}, + {vector_components_str}, + {value_functions_str}, + ] + """ + ) + INDENT = " " + DBCs = DBCs_TEMPLATE.format( + location_functions_str=indent("[" + ", ".join(location_funcs) + "]", INDENT), + vector_components_str=indent( + "[" + ", ".join(str(i) for i in vector_comps) + "]", INDENT + ), + value_functions_str=indent("[" + ", ".join(value_funcs) + "]", INDENT), + ) + + solver_name = solver_options.pop("name") + with path.open("wt") as fh: + fh.write( + TEMPLATE.format( + case_name="polycrystal", + strain=loading["strain"], + num_increments=loading["num_increments"], + direction=_DIR_LOOKUP[loading["direction"].lower()], + total_time=loading["total_time"], + dirichlet_BCs=indent(DBCs, INDENT), + dirichlet_ut_idx=u_t_idx, + solver_name=solver_name, + solver_options=solver_options, + ) + ) + + +def __apply_for_values( + dir_BC, func_name, loc_funcs, vec_comps, value_funcs +) -> tuple[str, int | None]: + out = "" + dirichlet_update_idx = None + for field_dir, value in dir_BC.value.items(): + loc_funcs.append(func_name) + vec_comps.append(BoundaryCondition._DIRS.index(field_dir)) + if value == 0: + # a function that takes a point and returns zero: + value_str = "zero_value" + if value_str not in value_funcs: + out += dedent( + """\ + def zero_value(point): + return 0 + + """ + ) + elif value.lower() == "u(t)": + # a function that takes a point and returns the required + # displacement at time t: + dirichlet_update_idx = len(value_funcs) + value_str = "constant_value(displacements[0])" + if value_str not in value_funcs: + out += dedent( + """\ + def constant_value(value): + def val_fn(point): + return value + return val_fn + + """ + ) + + value_funcs.append(value_str) + return out, dirichlet_update_idx + + +def create_JAX_CPFEM_boundary_conditions_code( + load_case, domain_size: Sequence[float | str] +) -> str: + """ + Create a string containing Python code that can be used to define this + load case (represented with Dirichlet boundary conditions) when using JAX-CPFEM. + """ + + dir_BCs = load_case.to_dirichlet_BCs() + step = load_case.steps[0] + + domain_size_arr = np.asarray(domain_size) + func_lst = set() + out = "" + dirichlet_update_idx = None + loc_funcs = [] + vec_comps = [] + value_funcs = [] + for dir_BC in dir_BCs: + corners = dir_BC.corners + faces = dir_BC.faces + num_regions = len(corners) if corners else len(faces) + for corner in corners or (): + func_name = BoundaryCondition._JAX_CPFEM_BC_FUNC_CORNER_MAP[tuple(corner)] + if func_name not in func_lst: + func_lst.update(func_name) + # apply domain size: + corner_arr = np.asarray(corner) + is_max = corner_arr == 1 + corner_arr[is_max] = domain_size_arr[is_max] + out += BoundaryCondition.get_corner_location_func_str(corner_arr) + out_i, update_idx = __apply_for_values( + dir_BC, func_name, loc_funcs, vec_comps, value_funcs + ) + out += out_i + + if update_idx is not None: + dirichlet_update_idx = update_idx + + for face in faces or (): + func_name = BoundaryCondition._JAX_CPFEM_BC_FUNC_FACE_MAP[face] + if func_name not in func_lst: + func_lst.update(func_name) + out += BoundaryCondition.get_face_location_func_str(face, domain_size) + out_i, update_idx = __apply_for_values( + dir_BC, func_name, loc_funcs, vec_comps, value_funcs + ) + out += out_i + if update_idx is not None: + dirichlet_update_idx = update_idx + + out += dedent( + """\ + displacements = np.linspace(0, {strain}*{domain_i}, {num_increments} + 1) + ts = np.linspace(0, {total_time}, {num_increments} + 1) + + dirichlet_bc_info = [ + {location_functions_str}, + {vector_components_str}, + {value_functions_str}, + ] + """ + ).format( + strain=step.strain, + domain_i=_DIR_LOOKUP[step.direction], + num_increments=step.num_increments, + total_time=step.total_time, + location_functions_str="[" + ", ".join(loc_funcs) + "]", + vector_components_str="[" + ", ".join(str(i) for i in vec_comps) + "]", + value_functions_str="[" + ", ".join(value_funcs) + "]", + ) + + return out, dirichlet_update_idx diff --git a/matflow/data/scripts/jax_cpfem/write_slip_systems.py b/matflow/data/scripts/jax_cpfem/write_slip_systems.py new file mode 100644 index 00000000..7065cada --- /dev/null +++ b/matflow/data/scripts/jax_cpfem/write_slip_systems.py @@ -0,0 +1,26 @@ +from textwrap import dedent + + +def write_slip_systems(path, slip_systems): + # slip plane normals, followed by slip directions + # FCC (?) + + # TEMPLATE = dedent( + # """\ + # 1 1 -1 0 1 1 + # 1 1 -1 1 0 1 + # 1 1 -1 1 -1 0 + # 1 -1 -1 0 1 -1 + # 1 -1 -1 1 0 1 + # 1 -1 -1 1 1 0 + # 1 -1 1 0 1 1 + # 1 -1 1 1 0 -1 + # 1 -1 1 1 1 0 + # 1 1 1 0 1 -1 + # 1 1 1 1 0 -1 + # 1 1 1 1 -1 0 + # """ + # ) + with path.open("wt") as fh: + for slip_plane, slip_dir in slip_systems: + fh.write(" ".join(str(i) for i in [*slip_plane, *slip_dir]) + "\n") diff --git a/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml b/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml new file mode 100644 index 00000000..febf5cd2 --- /dev/null +++ b/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml @@ -0,0 +1,150 @@ +template_components: + environments: + - name: jax_cpfem + executables: + - label: run_jax_cpfem + instances: + - command: docker run --rm -it -v "${PWD}:/app" jax-cpfem + parallel_mode: null + num_cores: 1 + + command_files: + - label: jax_cpfem_mesh_data + name: + name: mesh.npz + doc: > + Information that we can use to generate a JAX-FEM hexahedral ("box") `Mesh` + object, and associate each grain with quaternion orientations. + - label: jax_cpfem_model_script + name: + name: model.py + - label: jax_cpfem_problem_script + name: + name: problem.py + - label: jax_cpfem_slip_systems + name: + name: input_slip_sys.txt + doc: Slip systems to model. + - label: jax_cpfem_stdout + name: + name: stdout.log + doc: Standard output stream from a JAX-CPFEM simulation. + - label: jax_cpfem_stderr + name: + name: stderr.log + doc: Standard error stream from a JAX-CPFEM simulation. + + task_schemas: + - objective: simulate_VE_loading + implementation: JAX_CPFEM + inputs: + - parameter: volume_element + - parameter: material_parameters + - parameter: slip_systems + - parameter: loading + - parameter: solver_options + - parameter: numerics + outputs: + - parameter: CP_outputs # temp naming + actions: + - requires_dir: true + environments: + - scope: + type: main + environment: jax_cpfem + - scope: + type: processing + environment: python_env + input_file_generators: + - input_file: jax_cpfem_mesh_data + from_inputs: [volume_element] + script: <> + - input_file: jax_cpfem_model_script + from_inputs: [material_parameters, numerics] + script: <> + - input_file: jax_cpfem_problem_script + from_inputs: [loading, solver_options] + script: <> + - input_file: jax_cpfem_slip_systems + from_inputs: [slip_systems] + script: <> + commands: + - command: <> problem + stdout: stdout.log + stderr: stderr.log + output_file_parsers: + CP_outputs: + from_files: [jax_cpfem_stdout] + script: <> + +# TODO: +# - CP sim comparison DAMASK? + +tasks: + - schema: generate_volume_element_from_voronoi + inputs: + periodic: false + VE_grid_size: [10, 10, 10] + orientations: + data: + - [1, 0, 0, 0] + unit_cell_alignment: { x: a, y: b, z: c } + representation: + type: quaternion + quat_order: scalar_vector + microstructure_seeds::from_random: + num_seeds: 1 + box_size: [0.1, 0.1, 0.1] + phase_label: copper + + - schema: visualise_VE_VTK + + - schema: simulate_VE_loading_JAX_CPFEM + inputs: + material_parameters: + C11: 1.684e5 # MPa + C12: 1.214e5 # MPa + C44: 0.754e5 # MPa + r: 1 # latent hardening + gss_initial: 60.8 # initial flow stress, in MPa + h: 541.5 # initial hardening, in MPa + t_sat: 109.8 # saturation strength, in MPa + gss_a: 2.5 # 'a' in Kalidini model + ao: 0.001 # reference strain rate + xm_d: 10 # rate sensitivity exponent (denominator) + solver_options: + name: jax_solver + # name: petsc_solver + # ksp_type: bcgsl + # pc_type: jacobi + slip_systems: # plane and direction + - [[1, 1, -1], [0, 1, 1]] + - [[1, 1, -1], [1, 0, 1]] + - [[1, 1, -1], [1, -1, 0]] + - [[1, -1, -1], [0, 1, -1]] + - [[1, -1, -1], [1, 0, 1]] + - [[1, -1, -1], [1, 1, 0]] + - [[1, -1, 1], [0, 1, 1]] + - [[1, -1, 1], [1, 0, -1]] + - [[1, -1, 1], [1, 1, 0]] + - [[1, 1, 1], [0, 1, -1]] + - [[1, 1, 1], [1, 0, -1]] + - [[1, 1, 1], [1, -1, 0]] + loading: + strain: 0.01 + direction: x # reproduces example, but should be z? + total_time: 0.1 + num_increments: 10 + dirichlet_BCs: + - points: bottom + values: + u_x: 0 + u_y: 0 + u_z: 0 + - points: top + values: + u_x: 0 + u_y: 0 + u_z: u(t) + numerics: + newton_max_sub_step: 5 diff --git a/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml b/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml new file mode 100644 index 00000000..a978d5c2 --- /dev/null +++ b/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml @@ -0,0 +1,175 @@ +template_components: + environments: + - name: jax_cpfem + executables: + - label: run_jax_cpfem + instances: + - command: docker run --rm -it -v "${PWD}:/app" jax-cpfem + parallel_mode: null + num_cores: 1 + + command_files: + - label: jax_cpfem_mesh_data + name: + name: mesh.npz + doc: > + Information that we can use to generate a JAX-FEM hexahedral ("box") `Mesh` + object, and associate each grain with quaternion orientations. + - label: jax_cpfem_model_script + name: + name: model.py + - label: jax_cpfem_problem_script + name: + name: problem.py + - label: jax_cpfem_slip_systems + name: + name: input_slip_sys.txt + doc: Slip systems to model. + - label: jax_cpfem_stdout + name: + name: stdout.log + doc: Standard output stream from a JAX-CPFEM simulation. + - label: jax_cpfem_stderr + name: + name: stderr.log + doc: Standard error stream from a JAX-CPFEM simulation. + + task_schemas: + - objective: simulate_VE_loading + implementation: JAX_CPFEM + inputs: + - parameter: volume_element + - parameter: material_parameters + - parameter: loading + - parameter: solver_options + - parameter: numerics + actions: + - requires_dir: true + environments: + - scope: + type: main + environment: jax_cpfem + - scope: + type: processing + environment: python_env + input_file_generators: + - input_file: jax_cpfem_mesh_data + from_inputs: [volume_element] + script: <> + - input_file: jax_cpfem_model_script + from_inputs: [material_parameters, numerics] + script: <> + - input_file: jax_cpfem_problem_script + from_inputs: [loading, solver_options] + script: <> + - input_file: jax_cpfem_slip_systems + from_inputs: [] + script: <> + commands: + - command: <> problem + stdout: stdout.log + stderr: stderr.log + +# TODO: +# - test mesh equivalence in polycrystal example, from provided gmsh .msh file +# - (can just look at first sim output VTU file (contains cell oris map) + +tasks: + - schema: generate_volume_element_from_voronoi + inputs: + periodic: false + VE_grid_size: [8, 8, 8] + orientations: + data: + - [ + -0.44841562434329973, + 0.16860994035907192, + -0.3509983118612716, + 0.8045460216342235, + ] + - [ + -0.1721519777399991, + 0.1221784375558591, + -0.7486299266954833, + -0.628481788767607, + ] + - [ + 0.1755902645854752, + -0.8369269701982096, + 0.21157163334810386, + 0.4732428018470682, + ] + - [ + -0.8663740014241987, + -0.187241901559973, + 0.4235298570897932, + -0.18697331389780547, + ] + - [ + -0.9366201723048194, + 0.07470626885605275, + 0.34088334168529333, + 0.03098666788741778, + ] + - [ + -0.462754124567833, + -0.3500710695060419, + -0.3866988406362011, + -0.7167795150120938, + ] + - [ + -0.6587757263479669, + -0.16019499696083303, + -0.3423669060109474, + 0.6504898208211397, + ] + - [ + 0.1985059035197109, + 0.6384482594496135, + 0.05206231874156126, + 0.7418010118898696, + ] + unit_cell_alignment: { x: a, y: b, z: c } + representation: + type: quaternion + quat_order: scalar_vector + microstructure_seeds::from_random: + num_seeds: 8 + box_size: [0.016, 0.016, 0.016] + phase_label: 304_steel + + - schema: visualise_VE_VTK + + - schema: simulate_VE_loading_JAX_CPFEM + inputs: + material_parameters: + C11: 2.622e5 # MPa + C12: 1.120e5 # MPa + C44: 0.746e5 # MPa + r: 1 # latent hardening + gss_initial: 90 # initial flow stress, in MPa + h: 392.9772 # initial hardening, in MPa + t_sat: 7295.1754 # saturation strength, in MPa + gss_a: 8 # 'a' in Kalidini model + ao: 0.001 # reference strain rate + xm_d: 120 # rate sensitivity exponent (denominator) + solver_options: + name: jax_solver + numerics: + newton_max_sub_step: 8 + loading: + strain: 0.01 + direction: x # reproduces example, but should be z? + total_time: 0.1 + num_increments: 50 + dirichlet_BCs: + - points: corner + values: + u_x: 0 + u_y: 0 + - points: bottom + values: + u_z: 0 + - points: top + values: + u_z: u(t) From 4386717d870815ba18e427737fd3da06bc653c49 Mon Sep 17 00:00:00 2001 From: Adam Date: Wed, 18 Jun 2025 16:15:55 +0100 Subject: [PATCH 05/11] fix: incorrect type and missing import --- matflow/param_classes/boundary_conditions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/matflow/param_classes/boundary_conditions.py b/matflow/param_classes/boundary_conditions.py index 6bf0c507..379680df 100644 --- a/matflow/param_classes/boundary_conditions.py +++ b/matflow/param_classes/boundary_conditions.py @@ -7,6 +7,7 @@ from collections.abc import Mapping, Sequence from typing import Literal +from textwrap import dedent from typing_extensions import Final, Self @@ -55,7 +56,9 @@ class BoundaryCondition: "+z": "face_pos_z", "-z": "face_neg_z", } - _LOCATION_FUNC_CORNER_BODY_TEMPLATE = "return np.allclose(point, {corner}, atol=1e-5)" + _LOCATION_FUNC_CORNER_BODY_TEMPLATE = ( + "return np.allclose(point, np.asarray({corner}), atol=1e-5)" + ) _LOCATION_FUNC_FACE_BODY_TEMPLATE = ( "return np.isclose(point[{comp_idx}], {value}, atol=1e-5)" ) From 0e934e4b5c6a313c9659ef67658ae49bf7b7cd4d Mon Sep 17 00:00:00 2001 From: Adam Date: Wed, 18 Jun 2025 16:18:40 +0100 Subject: [PATCH 06/11] feat: add `LoadStep.one_dimensional` for uniaxial 1D loading with no non-axial deformation --- matflow/param_classes/load.py | 196 ++++++++++++++++++++++++++++------ 1 file changed, 163 insertions(+), 33 deletions(-) diff --git a/matflow/param_classes/load.py b/matflow/param_classes/load.py index a80a224f..9aae04c9 100644 --- a/matflow/param_classes/load.py +++ b/matflow/param_classes/load.py @@ -279,6 +279,61 @@ def example_uniaxial(cls) -> Self: target_def_grad_rate=rate, ) + @classmethod + def __pre_process_uniaxial_like_method_args( + cls, + direction: str, + target_strain: float | None = None, + target_strain_rate: float | None = None, + target_def_grad: float | None = None, + target_def_grad_rate: float | None = None, + ) -> tuple[int, float | None, float | None, float | None, float | None]: + """Perform common processing on a subset of arguments for `uniaxial` and + `one_dimensional` load case methods.""" + + # Validation: + msg = ( + "Specify either `target_strain`, `target_strain_rate`, " + "``target_def_grad` or target_def_grad_rate`." + ) + strain_arg = ( + target_strain, + target_strain_rate, + target_def_grad, + target_def_grad_rate, + ) + if sum(s is not None for s in strain_arg) != 1: + raise ValueError(msg) + + # convert strain (rate) to deformation gradient (rate) components, and ensure both + # strain(_rate) and def_grad(_rate) are populated: + if target_strain is not None: + target_def_grad = 1 + target_strain + elif target_def_grad is not None: + target_strain = target_def_grad - 1 + + if target_strain_rate is not None: + target_def_grad_rate = target_strain_rate + elif target_def_grad_rate is not None: + target_strain_rate = target_def_grad_rate + + try: + loading_dir_idx = cls._DIR_IDX.index(direction) + except ValueError: + msg = ( + f'Loading direction "{direction}" not allowed. It should be one of "x", ' + f'"y" or "z".' + ) + raise ValueError(msg) + + return ( + loading_dir_idx, + target_strain, + target_strain_rate, + target_def_grad, + target_def_grad_rate, + ) + @classmethod def uniaxial( cls, @@ -318,6 +373,19 @@ def uniaxial( """ _method_name = "uniaxial" + ( + loading_dir_idx, + target_strain, + target_strain_rate, + target_def_grad, + target_def_grad_rate, + ) = cls.__pre_process_uniaxial_like_method_args( + direction, + target_strain, + target_strain_rate, + target_def_grad, + target_def_grad_rate, + ) _method_args = { "total_time": total_time, "num_increments": num_increments, @@ -329,52 +397,105 @@ def uniaxial( "dump_frequency": dump_frequency, } - # Validation: - msg = ( - "Specify either `target_strain`, `target_strain_rate`, " - "``target_def_grad` or target_def_grad_rate`." + if target_def_grad_rate is not None: + def_grad_val = target_def_grad_rate + else: + def_grad_val = target_def_grad + + dg_arr = np.ma.masked_array(np.zeros((3, 3)), mask=np.eye(3)) + stress_arr = np.ma.masked_array(np.zeros((3, 3)), mask=np.logical_not(np.eye(3))) + + dg_arr[loading_dir_idx, loading_dir_idx] = def_grad_val + dg_arr.mask[loading_dir_idx, loading_dir_idx] = False + stress_arr.mask[loading_dir_idx, loading_dir_idx] = True + + def_grad_aim = dg_arr if target_def_grad is not None else None + def_grad_rate = dg_arr if target_def_grad_rate is not None else None + + obj = cls( + direction=direction, + total_time=total_time, + num_increments=num_increments, + target_def_grad=def_grad_aim, + target_def_grad_rate=def_grad_rate, + stress=stress_arr, + dump_frequency=dump_frequency, ) - strain_arg = ( + return obj._remember_name_args(_method_name, _method_args) + + @classmethod + def one_dimensional( + cls, + total_time: float | int, + num_increments: int, + direction: str, + target_strain: float | None = None, + target_strain_rate: float | None = None, + target_def_grad_rate: float | None = None, + target_def_grad: float | None = None, + dump_frequency: int = 1, + ) -> Self: + """ + Generate a load step that deforms in only one dimension, such that the geometrical + result is like that for a material with a Poisson's ratio (nu) of zero. + + Parameters + ---------- + total_time + Total simulation time. + num_increments + Number of simulation increments. + direction : str + A single character, "x", "y" or "z", representing the loading direction. + target_def_grad : float + Target deformation gradient to achieve along the loading direction component. + target_strain: float + Target engineering strain to achieve along the loading direction. Specify at + most one of `target_strain` and `target_def_grad`. + target_def_grad_rate : float + Target deformation gradient rate to achieve along the loading direction + component. + target_strain_rate: float + Target engineering strain rate to achieve along the loading direction. Specify + at most one of `target_strain_rate` and `target_def_grad_rate`. + dump_frequency : int, optional + By default, 1, meaning results are written out every increment. + """ + + _method_name = "one_dimensional" + ( + loading_dir_idx, + target_strain, + target_strain_rate, + target_def_grad, + target_def_grad_rate, + ) = cls.__pre_process_uniaxial_like_method_args( + direction, target_strain, target_strain_rate, target_def_grad, target_def_grad_rate, ) - if sum(s is not None for s in strain_arg) != 1: - raise ValueError(msg) - - # convert strain (rate) to deformation gradient (rate) components, and ensure both - # strain(_rate) and def_grad(_rate) are populated: - if target_strain is not None: - target_def_grad = 1 + target_strain - elif target_def_grad is not None: - target_strain = target_def_grad - 1 - - if target_strain_rate is not None: - target_def_grad_rate = target_strain_rate - elif target_def_grad_rate is not None: - target_strain_rate = target_def_grad_rate + _method_args = { + "total_time": total_time, + "num_increments": num_increments, + "direction": direction, + "target_strain": target_strain, + "target_strain_rate": target_strain_rate, + "target_def_grad": target_def_grad, + "target_def_grad_rate": target_def_grad_rate, + "dump_frequency": dump_frequency, + } if target_def_grad_rate is not None: def_grad_val = target_def_grad_rate else: def_grad_val = target_def_grad - try: - loading_dir_idx = cls._DIR_IDX.index(direction) - except ValueError: - msg = ( - f'Loading direction "{direction}" not allowed. It should be one of "x", ' - f'"y" or "z".' - ) - raise ValueError(msg) - - dg_arr = np.ma.masked_array(np.zeros((3, 3)), mask=np.eye(3)) - stress_arr = np.ma.masked_array(np.zeros((3, 3)), mask=np.logical_not(np.eye(3))) - + dg_arr = np.ma.masked_array(np.zeros((3, 3)), mask=np.zeros((3, 3))) dg_arr[loading_dir_idx, loading_dir_idx] = def_grad_val - dg_arr.mask[loading_dir_idx, loading_dir_idx] = False - stress_arr.mask[loading_dir_idx, loading_dir_idx] = True + + stress_arr = np.ma.masked_array(np.zeros((3, 3)), mask=np.ones((3, 3))) def_grad_aim = dg_arr if target_def_grad is not None else None def_grad_rate = dg_arr if target_def_grad_rate is not None else None @@ -1020,6 +1141,15 @@ def uniaxial(cls, **kwargs) -> Self: """ return cls(steps=[LoadStep.uniaxial(**kwargs)]) + @classmethod + def one_dimensional(cls, **kwargs) -> Self: + """A single-step one-dimensional load case. + + See :py:meth:`~LoadStep.one_dimensional` for argument documentation. + + """ + return cls(steps=[LoadStep.one_dimensional(**kwargs)]) + @classmethod def biaxial(cls, **kwargs) -> Self: """A single-step biaxial load case. From 9c6711e8d0b16e8f41dd362f1ae0083d8cafc1b7 Mon Sep 17 00:00:00 2001 From: Adam Date: Wed, 18 Jun 2025 20:53:04 +0100 Subject: [PATCH 07/11] fix: add `BoundaryCondition.one_dimensional` and update load cases in JAX-CPFEM examples --- .../workflows/simulate_JAX_CPFEM_copper.yaml | 96 ++----------------- .../workflows/simulate_JAX_CPFEM_steel.yaml | 86 ++--------------- matflow/param_classes/boundary_conditions.py | 15 ++- matflow/param_classes/load.py | 9 +- 4 files changed, 36 insertions(+), 170 deletions(-) diff --git a/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml b/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml index febf5cd2..db5a10a6 100644 --- a/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml +++ b/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml @@ -8,78 +8,6 @@ template_components: parallel_mode: null num_cores: 1 - command_files: - - label: jax_cpfem_mesh_data - name: - name: mesh.npz - doc: > - Information that we can use to generate a JAX-FEM hexahedral ("box") `Mesh` - object, and associate each grain with quaternion orientations. - - label: jax_cpfem_model_script - name: - name: model.py - - label: jax_cpfem_problem_script - name: - name: problem.py - - label: jax_cpfem_slip_systems - name: - name: input_slip_sys.txt - doc: Slip systems to model. - - label: jax_cpfem_stdout - name: - name: stdout.log - doc: Standard output stream from a JAX-CPFEM simulation. - - label: jax_cpfem_stderr - name: - name: stderr.log - doc: Standard error stream from a JAX-CPFEM simulation. - - task_schemas: - - objective: simulate_VE_loading - implementation: JAX_CPFEM - inputs: - - parameter: volume_element - - parameter: material_parameters - - parameter: slip_systems - - parameter: loading - - parameter: solver_options - - parameter: numerics - outputs: - - parameter: CP_outputs # temp naming - actions: - - requires_dir: true - environments: - - scope: - type: main - environment: jax_cpfem - - scope: - type: processing - environment: python_env - input_file_generators: - - input_file: jax_cpfem_mesh_data - from_inputs: [volume_element] - script: <> - - input_file: jax_cpfem_model_script - from_inputs: [material_parameters, numerics] - script: <> - - input_file: jax_cpfem_problem_script - from_inputs: [loading, solver_options] - script: <> - - input_file: jax_cpfem_slip_systems - from_inputs: [slip_systems] - script: <> - commands: - - command: <> problem - stdout: stdout.log - stderr: stderr.log - output_file_parsers: - CP_outputs: - from_files: [jax_cpfem_stdout] - script: <> - -# TODO: -# - CP sim comparison DAMASK? - tasks: - schema: generate_volume_element_from_voronoi inputs: @@ -99,6 +27,14 @@ tasks: - schema: visualise_VE_VTK + - schema: define_load_case + inputs: + load_case::one_dimensional: + num_increments: 10 + total_time: 0.1 + target_strain: 0.01 + direction: x + - schema: simulate_VE_loading_JAX_CPFEM inputs: material_parameters: @@ -130,21 +66,5 @@ tasks: - [[1, 1, 1], [0, 1, -1]] - [[1, 1, 1], [1, 0, -1]] - [[1, 1, 1], [1, -1, 0]] - loading: - strain: 0.01 - direction: x # reproduces example, but should be z? - total_time: 0.1 - num_increments: 10 - dirichlet_BCs: - - points: bottom - values: - u_x: 0 - u_y: 0 - u_z: 0 - - points: top - values: - u_x: 0 - u_y: 0 - u_z: u(t) numerics: newton_max_sub_step: 5 diff --git a/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml b/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml index a978d5c2..7170e993 100644 --- a/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml +++ b/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml @@ -8,68 +8,6 @@ template_components: parallel_mode: null num_cores: 1 - command_files: - - label: jax_cpfem_mesh_data - name: - name: mesh.npz - doc: > - Information that we can use to generate a JAX-FEM hexahedral ("box") `Mesh` - object, and associate each grain with quaternion orientations. - - label: jax_cpfem_model_script - name: - name: model.py - - label: jax_cpfem_problem_script - name: - name: problem.py - - label: jax_cpfem_slip_systems - name: - name: input_slip_sys.txt - doc: Slip systems to model. - - label: jax_cpfem_stdout - name: - name: stdout.log - doc: Standard output stream from a JAX-CPFEM simulation. - - label: jax_cpfem_stderr - name: - name: stderr.log - doc: Standard error stream from a JAX-CPFEM simulation. - - task_schemas: - - objective: simulate_VE_loading - implementation: JAX_CPFEM - inputs: - - parameter: volume_element - - parameter: material_parameters - - parameter: loading - - parameter: solver_options - - parameter: numerics - actions: - - requires_dir: true - environments: - - scope: - type: main - environment: jax_cpfem - - scope: - type: processing - environment: python_env - input_file_generators: - - input_file: jax_cpfem_mesh_data - from_inputs: [volume_element] - script: <> - - input_file: jax_cpfem_model_script - from_inputs: [material_parameters, numerics] - script: <> - - input_file: jax_cpfem_problem_script - from_inputs: [loading, solver_options] - script: <> - - input_file: jax_cpfem_slip_systems - from_inputs: [] - script: <> - commands: - - command: <> problem - stdout: stdout.log - stderr: stderr.log - # TODO: # - test mesh equivalence in polycrystal example, from provided gmsh .msh file # - (can just look at first sim output VTU file (contains cell oris map) @@ -140,6 +78,14 @@ tasks: - schema: visualise_VE_VTK + - schema: define_load_case + inputs: + load_case::uniaxial: + num_increments: 10 + total_time: 0.1 + target_strain: 0.01 + direction: x + - schema: simulate_VE_loading_JAX_CPFEM inputs: material_parameters: @@ -157,19 +103,3 @@ tasks: name: jax_solver numerics: newton_max_sub_step: 8 - loading: - strain: 0.01 - direction: x # reproduces example, but should be z? - total_time: 0.1 - num_increments: 50 - dirichlet_BCs: - - points: corner - values: - u_x: 0 - u_y: 0 - - points: bottom - values: - u_z: 0 - - points: top - values: - u_z: u(t) diff --git a/matflow/param_classes/boundary_conditions.py b/matflow/param_classes/boundary_conditions.py index 379680df..86cf67dd 100644 --- a/matflow/param_classes/boundary_conditions.py +++ b/matflow/param_classes/boundary_conditions.py @@ -89,7 +89,7 @@ def __repr__(self): ) @classmethod - def uniaxial_tension(cls, direction: Literal["x", "y", "z"]) -> list[Self]: + def uniaxial(cls, direction: Literal["x", "y", "z"]) -> list[Self]: """Generate a list of boundary conditions that correspond to uniaxial loading along the specified direction.""" non_axial_dirs = sorted(list(set(cls._DIRS).difference(direction))) @@ -99,6 +99,19 @@ def uniaxial_tension(cls, direction: Literal["x", "y", "z"]) -> list[Self]: cls(faces=(f"+{direction}",), value={direction: "u(t)"}), ] + @classmethod + def one_dimensional(cls, direction: Literal["x", "y", "z"]) -> list[Self]: + """Generate a list of boundary conditions that correspond to one-dimensional + loading along the specified direction (no deformation in non-axial directions).""" + non_axial_dirs = sorted(list(set(cls._DIRS).difference(direction))) + return [ + cls(faces=(f"-{direction}",), value={"x": 0, "y": 0, "z": 0}), + cls( + faces=(f"+{direction}",), + value={non_axial_dirs[0]: 0, non_axial_dirs[1]: 0, direction: "u(t)"}, + ), + ] + @staticmethod def get_corner_location_func_str(corner) -> str: """Generate Python code that defines a function that return True if the provided diff --git a/matflow/param_classes/load.py b/matflow/param_classes/load.py index 9aae04c9..55e7495d 100644 --- a/matflow/param_classes/load.py +++ b/matflow/param_classes/load.py @@ -85,6 +85,7 @@ class LoadStep(ParameterValue): """ _DIR_IDX: Final[tuple[str, ...]] = ("x", "y", "z") + _SUPPORTS_SCALAR_STRAIN: Final[tuple[str]] = ("uniaxial", "one_dimensional") def __init__( self, @@ -221,7 +222,7 @@ def strain(self) -> float | None: For a limited subset of load step types (e.g. uniaxial), return the scalar target strain. """ - if self.type in ("uniaxial",): + if self.type in self._SUPPORTS_SCALAR_STRAIN: return self.method_args["target_strain"] @property @@ -230,7 +231,7 @@ def strain_rate(self) -> float | None: For a limited subset of load step types (e.g. uniaxial), return the scalar target strain rate. """ - if self.type in ("uniaxial",): + if self.type in self._SUPPORTS_SCALAR_STRAIN: return self.method_args["target_strain_rate"] def __repr__(self) -> str: @@ -253,7 +254,9 @@ def to_dirichlet_BCs(self) -> list[BoundaryCondition]: into Dirichlet boundary conditions.""" if self.type == "uniaxial": - return BoundaryCondition.uniaxial_tension(self.direction) + return BoundaryCondition.uniaxial(self.direction) + elif self.type == "one_dimensional": + return BoundaryCondition.one_dimensional(self.direction) raise NotImplementedError( "Cannot express this load step in terms of boundary conditions." ) From 38dc5f11f500c70fdb2d596b269c5485535a4d47 Mon Sep 17 00:00:00 2001 From: Adam Date: Wed, 18 Jun 2025 21:11:59 +0100 Subject: [PATCH 08/11] fix: typing --- matflow/param_classes/boundary_conditions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/matflow/param_classes/boundary_conditions.py b/matflow/param_classes/boundary_conditions.py index 86cf67dd..22889ac6 100644 --- a/matflow/param_classes/boundary_conditions.py +++ b/matflow/param_classes/boundary_conditions.py @@ -5,6 +5,8 @@ """ +from __future__ import annotations + from collections.abc import Mapping, Sequence from typing import Literal from textwrap import dedent From 3be8f1655027cd82938b2b3ca46f5b7255d31e3b Mon Sep 17 00:00:00 2001 From: Adam Date: Wed, 18 Jun 2025 21:35:25 +0100 Subject: [PATCH 09/11] fix: remove environment definitions from JAX-CPFEM workflows --- .../workflows/simulate_JAX_CPFEM_DAMASK.yaml | 117 ++++++++++++++++++ .../workflows/simulate_JAX_CPFEM_copper.yaml | 10 -- .../workflows/simulate_JAX_CPFEM_steel.yaml | 10 -- 3 files changed, 117 insertions(+), 20 deletions(-) create mode 100644 matflow/data/workflows/simulate_JAX_CPFEM_DAMASK.yaml diff --git a/matflow/data/workflows/simulate_JAX_CPFEM_DAMASK.yaml b/matflow/data/workflows/simulate_JAX_CPFEM_DAMASK.yaml new file mode 100644 index 00000000..8e14a0fb --- /dev/null +++ b/matflow/data/workflows/simulate_JAX_CPFEM_DAMASK.yaml @@ -0,0 +1,117 @@ +tasks: + - schema: generate_volume_element_from_voronoi + inputs: + periodic: false + VE_grid_size: [10, 10, 10] + orientations: + data: + - [1, 0, 0, 0] + unit_cell_alignment: { x: a, y: b, z: c } + representation: + type: quaternion + quat_order: scalar_vector + microstructure_seeds::from_random: + num_seeds: 1 + box_size: [0.1, 0.1, 0.1] + phase_label: copper + + - schema: visualise_VE_VTK + + - schema: define_load_case + inputs: + load_case::uniaxial: + total_time: 0.1 + num_increments: 10 + direction: z + target_strain: 0.01 + dump_frequency: 1 + + - schema: simulate_VE_loading_damask + inputs: + homogenization: + SX: + mechanical: { type: "pass" } + N_constituents: 1 + damask_phases: + copper: # TODO: # change to Cu material parameters! + lattice: cF + mechanical: + output: [F, P, F_e, F_p, L_p, O] + elastic: + type: Hooke + C_11: 106750000000 + C_12: 60410000000 + C_44: 28340000000 + plastic: + type: phenopowerlaw + N_sl: [12] + a_sl: 2.25 + atol_xi: 1 + dot_gamma_0_sl: 0.001 + h_0_sl-sl: 75.0e+6 + h_sl-sl: [1, 1, 1.4, 1.4, 1.4, 1.4, 1.4] + n_sl: 20 + output: [xi_sl] + xi_0_sl: [31.0e+6] + xi_inf_sl: [63.0e+6] + damask_post_processing: + - name: add_stress_Cauchy + args: { P: P, F: F } + opts: { add_Mises: true } + - name: add_strain + args: { F: F, t: V, m: 0 } + opts: { add_Mises: true } + - name: add_strain + args: { F: F_p, t: V, m: 0 } + opts: { add_Mises: true } + - name: add_IPF_color + args: { l: [0, 0, 1] } + VE_response_data: + phase_data: + - field_name: sigma_vM + phase_name: copper + out_name: vol_avg_equivalent_stress + transforms: [{ mean_along_axes: 1 }] + - field_name: epsilon_V^0(F)_vM + phase_name: copper + out_name: vol_avg_equivalent_strain + transforms: [{ mean_along_axes: 1 }] + - field_name: epsilon_V^0(F_p)_vM + phase_name: copper + out_name: vol_avg_equivalent_plastic_strain + transforms: [{ mean_along_axes: 1 }] + field_data: + - field_name: phase + - field_name: O + grain_data: + - field_name: O + increments: [{ values: [0, -1] }] + + - schema: simulate_VE_loading_JAX_CPFEM + inputs: + material_parameters: + C11: 1.684e5 # MPa + C12: 1.214e5 # MPa + C44: 0.754e5 # MPa + r: 1 # latent hardening + gss_initial: 60.8 # initial flow stress, in MPa + h: 541.5 # initial hardening, in MPa + t_sat: 109.8 # saturation strength, in MPa + gss_a: 2.5 # 'a' in Kalidini model + ao: 0.001 # reference strain rate + xm_d: 10 # rate sensitivity exponent (denominator) + slip_systems: # plane and direction + - [[1, 1, -1], [0, 1, 1]] + - [[1, 1, -1], [1, 0, 1]] + - [[1, 1, -1], [1, -1, 0]] + - [[1, -1, -1], [0, 1, -1]] + - [[1, -1, -1], [1, 0, 1]] + - [[1, -1, -1], [1, 1, 0]] + - [[1, -1, 1], [0, 1, 1]] + - [[1, -1, 1], [1, 0, -1]] + - [[1, -1, 1], [1, 1, 0]] + - [[1, 1, 1], [0, 1, -1]] + - [[1, 1, 1], [1, 0, -1]] + - [[1, 1, 1], [1, -1, 0]] + numerics: + newton_max_sub_step: 5 diff --git a/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml b/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml index db5a10a6..4c9389a9 100644 --- a/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml +++ b/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml @@ -1,13 +1,3 @@ -template_components: - environments: - - name: jax_cpfem - executables: - - label: run_jax_cpfem - instances: - - command: docker run --rm -it -v "${PWD}:/app" jax-cpfem - parallel_mode: null - num_cores: 1 - tasks: - schema: generate_volume_element_from_voronoi inputs: diff --git a/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml b/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml index 7170e993..63bd22ff 100644 --- a/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml +++ b/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml @@ -1,13 +1,3 @@ -template_components: - environments: - - name: jax_cpfem - executables: - - label: run_jax_cpfem - instances: - - command: docker run --rm -it -v "${PWD}:/app" jax-cpfem - parallel_mode: null - num_cores: 1 - # TODO: # - test mesh equivalence in polycrystal example, from provided gmsh .msh file # - (can just look at first sim output VTU file (contains cell oris map) From 0c273c278bbcfcf334d341bcbd21ec5f107f8186 Mon Sep 17 00:00:00 2001 From: Adam Date: Wed, 18 Jun 2025 21:53:16 +0100 Subject: [PATCH 10/11] fix: missing changes! --- .../jax_cpfem/write_problem_py_script.py | 115 ++---------------- .../template_components/command_files.yaml | 28 +++++ .../template_components/task_schemas.yaml | 43 +++++++ 3 files changed, 83 insertions(+), 103 deletions(-) diff --git a/matflow/data/scripts/jax_cpfem/write_problem_py_script.py b/matflow/data/scripts/jax_cpfem/write_problem_py_script.py index 5df5682e..f2247f8c 100644 --- a/matflow/data/scripts/jax_cpfem/write_problem_py_script.py +++ b/matflow/data/scripts/jax_cpfem/write_problem_py_script.py @@ -1,6 +1,10 @@ from textwrap import dedent, indent from collections.abc import Sequence +import numpy as np + +from matflow.param_classes.boundary_conditions import BoundaryCondition + _DIR_LOOKUP = { "x": "Lx", "y": "Ly", @@ -8,7 +12,7 @@ } -def write_problem_py_script(path, loading, solver_options): +def write_problem_py_script(path, load_case, solver_options): TEMPLATE = dedent( """\ @@ -68,62 +72,6 @@ def problem(): Lz = np.max(mesh.points[:, 2]) print(f"Domain size: {{Lx}}, {{Ly}}, {{Lz}}") - displacements = np.linspace(0, {strain}*{direction}, {num_increments} + 1) - ts = np.linspace(0, {total_time}, {num_increments} + 1) - - ## Hu: Define index of points and faces - def corner(point): - flag_x = np.isclose(point[0], 0.0, atol=1e-5) - flag_y = np.isclose(point[1], 0.0, atol=1e-5) - flag_z = np.isclose(point[2], Lz, atol=1e-5) - return np.logical_and(np.logical_and(flag_x, flag_y), flag_z) - - def corner2(point): - flag_x = np.isclose(point[0], 0.0, atol=1e-5) - flag_y = np.isclose(point[1], 0.0, atol=1e-5) - flag_z = np.isclose(point[2], 0.0, atol=1e-5) - return np.logical_and(np.logical_and(flag_x, flag_y), flag_z) - - def corner3(point): - flag_x = np.isclose(point[0], Lx, atol=1e-5) - flag_y = np.isclose(point[1], 0.0, atol=1e-5) - flag_z = np.isclose(point[2], 0.0, atol=1e-5) - return np.logical_and(np.logical_and(flag_x, flag_y), flag_z) - - def corner4(point): - flag_x = np.isclose(point[0], Lx, atol=1e-5) - flag_y = np.isclose(point[1], 0.0, atol=1e-5) - flag_z = np.isclose(point[2], Lz, atol=1e-5) - return np.logical_and(np.logical_and(flag_x, flag_y), flag_z) - - def left(point): - return np.isclose(point[0], 0.0, atol=1e-5) - - def right(point): - return np.isclose(point[0], Lx, atol=1e-5) - - def front(point): - return np.isclose(point[1], 0.0, atol=1e-5) - - def back(point): - return np.isclose(point[1], Ly, atol=1e-5) - - def bottom(point): - return np.isclose(point[2], 0.0, atol=1e-5) - - def top(point): - return np.isclose(point[2], Lz, atol=1e-5) - - ## Hu: Define dirichlet B.C. - def zero_dirichlet_val(point): - return 0.0 - - def get_dirichlet_top(disp): - def val_fn(point): - return disp - - return val_fn - {dirichlet_BCs} ## Hu: Define CPFEM problem on top of JAX-FEM @@ -156,7 +104,7 @@ def val_fn(point): ## Hu: Reset Dirichlet boundary conditions. ## Hu: Useful when a time-dependent problem is solved, and at each iteration the boundary condition needs to be updated. - dirichlet_bc_info[-1][{dirichlet_ut_idx}] = get_dirichlet_top(displacements[i + 1]) + dirichlet_bc_info[-1][{dirichlet_ut_idx}] = constant_value(displacements[i + 1]) problem.fes[0].update_Dirichlet_boundary_conditions(dirichlet_bc_info) ## Hu: Set up internal variables of previous step for inner Newton's method @@ -245,57 +193,18 @@ def val_fn(point): """ ) - DISP_COMP_LOOKUP = {"u_x": 0, "u_y": 1, "u_z": 2} - - u_t_idx = None # the of the BC where we specify the required displacement at time t - location_funcs = [] - vector_comps = [] - value_funcs = [] - for dbc in loading["dirichlet_BCs"]: - for disp_comp, value in dbc["values"].items(): - - if value == 0: - # a function that takes a point and returns zero: - value = "zero_dirichlet_val" - elif value.lower() == "u(t)": - # a function that takes a point and returns the required displacement at - # time t: - value = "get_dirichlet_top(displacements[0])" - u_t_idx = len(location_funcs) - - location_funcs.append(dbc["points"]) - vector_comps.append(DISP_COMP_LOOKUP[disp_comp]) - value_funcs.append(value) - - DBCs_TEMPLATE = dedent( - """\ - dirichlet_bc_info = [ - {location_functions_str}, - {vector_components_str}, - {value_functions_str}, - ] - """ - ) - INDENT = " " - DBCs = DBCs_TEMPLATE.format( - location_functions_str=indent("[" + ", ".join(location_funcs) + "]", INDENT), - vector_components_str=indent( - "[" + ", ".join(str(i) for i in vector_comps) + "]", INDENT - ), - value_functions_str=indent("[" + ", ".join(value_funcs) + "]", INDENT), + dirichlet_BCs, dirichlet_ut_idx = create_JAX_CPFEM_boundary_conditions_code( + load_case=load_case, domain_size=["Lx", "Ly", "Lz"] ) + INDENT = " " solver_name = solver_options.pop("name") with path.open("wt") as fh: fh.write( TEMPLATE.format( case_name="polycrystal", - strain=loading["strain"], - num_increments=loading["num_increments"], - direction=_DIR_LOOKUP[loading["direction"].lower()], - total_time=loading["total_time"], - dirichlet_BCs=indent(DBCs, INDENT), - dirichlet_ut_idx=u_t_idx, + dirichlet_BCs=indent(dirichlet_BCs, INDENT), + dirichlet_ut_idx=dirichlet_ut_idx, solver_name=solver_name, solver_options=solver_options, ) @@ -343,7 +252,7 @@ def val_fn(point): def create_JAX_CPFEM_boundary_conditions_code( load_case, domain_size: Sequence[float | str] -) -> str: +) -> tuple[str, int]: """ Create a string containing Python code that can be used to define this load case (represented with Dirichlet boundary conditions) when using JAX-CPFEM. diff --git a/matflow/data/template_components/command_files.yaml b/matflow/data/template_components/command_files.yaml index 9c181f30..0116eb4a 100644 --- a/matflow/data/template_components/command_files.yaml +++ b/matflow/data/template_components/command_files.yaml @@ -69,3 +69,31 @@ name: name: pipeline.xdmf doc: DREAM.3D model data and metadata. + +- label: jax_cpfem_mesh_data + name: + name: mesh.npz + doc: > + Information that we can use to generate a JAX-FEM hexahedral ("box") `Mesh` + object, and associate each grain with quaternion orientations. +- label: jax_cpfem_model_script + name: + name: model.py + doc: > + JAX-CPFEM model definition script, including material parameter definitions. +- label: jax_cpfem_problem_script + name: + name: problem.py + doc: JAX-CPFEM problem definition script, including boundary condition definitions. +- label: jax_cpfem_slip_systems + name: + name: input_slip_sys.txt + doc: Slip systems to model. +- label: jax_cpfem_stdout + name: + name: stdout.log + doc: Standard output stream from a JAX-CPFEM simulation. +- label: jax_cpfem_stderr + name: + name: stderr.log + doc: Standard error stream from a JAX-CPFEM simulation. diff --git a/matflow/data/template_components/task_schemas.yaml b/matflow/data/template_components/task_schemas.yaml index 8d60cf52..ea3f8191 100644 --- a/matflow/data/template_components/task_schemas.yaml +++ b/matflow/data/template_components/task_schemas.yaml @@ -792,6 +792,49 @@ script: <> inputs: [damask_post_processing, VE_response_data, damask_viz] +- objective: simulate_VE_loading + doc: Simulate applying a load case to a volume element using JAX-CPFEM. + implementation: JAX_CPFEM + inputs: + - parameter: volume_element + - parameter: material_parameters + - parameter: slip_systems + - parameter: load_case + - parameter: solver_options + - parameter: numerics + outputs: + - parameter: CP_outputs # temp naming + actions: + - requires_dir: true + environments: + - scope: + type: main + environment: jax_cpfem + - scope: + type: processing + environment: python_env + input_file_generators: + - input_file: jax_cpfem_mesh_data + from_inputs: [volume_element] + script: <> + - input_file: jax_cpfem_model_script + from_inputs: [material_parameters, numerics] + script: <> + - input_file: jax_cpfem_problem_script + from_inputs: [load_case, solver_options] + script: <> + - input_file: jax_cpfem_slip_systems + from_inputs: [slip_systems] + script: <> + commands: + - command: <> problem + stdout: stdout.log + stderr: stderr.log + output_file_parsers: + CP_outputs: + from_files: [jax_cpfem_stdout] + script: <> + - objective: read_tensile_test doc: Read tensile test data from CSV. method: from_CSV From ddc5909077edcda77994dd880870bdf5658f8eb8 Mon Sep 17 00:00:00 2001 From: Adam Date: Thu, 19 Jun 2025 10:39:24 +0100 Subject: [PATCH 11/11] fix: missing future import --- matflow/data/scripts/jax_cpfem/write_problem_py_script.py | 1 + 1 file changed, 1 insertion(+) diff --git a/matflow/data/scripts/jax_cpfem/write_problem_py_script.py b/matflow/data/scripts/jax_cpfem/write_problem_py_script.py index f2247f8c..7fe6c917 100644 --- a/matflow/data/scripts/jax_cpfem/write_problem_py_script.py +++ b/matflow/data/scripts/jax_cpfem/write_problem_py_script.py @@ -1,3 +1,4 @@ +from __future__ import annotations from textwrap import dedent, indent from collections.abc import Sequence