diff --git a/matflow/data/scripts/jax_cpfem/parse_stdout.py b/matflow/data/scripts/jax_cpfem/parse_stdout.py new file mode 100644 index 00000000..e4a4e85c --- /dev/null +++ b/matflow/data/scripts/jax_cpfem/parse_stdout.py @@ -0,0 +1,49 @@ +import os +import re +from pathlib import Path + +import numpy as np + +_PAT_FP = ( + r"Fp\s*=\s*\[\[\s*([-+]?\d*\.\d*\s+[-+]?\d*\.\d*\s+[-+]?\d*\.\d*)\s*\]\s*" + r"\[\s*([-+]?\d*\.\d*\s+[-+]?\d*\.\d*\s+[-+]?\d*\.\d*)\s*\]\s*" + r"\[\s*([-+]?\d*\.\d*\s+[-+]?\d*\.\d*\s+[-+]?\d*\.\d*)\s*\]\s*\]" +) +_PAT_key_data_list = R"{key}:\s*\[\s*((?:\d+\.\d+\s*)+)\]" +_PAT_STRESS_XX = _PAT_key_data_list.format(key="stress_xx") +_PAT_STRESS_YY = _PAT_key_data_list.format(key="stress_yy") +_PAT_STRESS_ZZ = _PAT_key_data_list.format(key="stress_zz") +_PAT_STRESS_VM = _PAT_key_data_list.format(key="von_mises_stress") + + +def parse_stdout(jax_cpfem_stdout: Path): + + encoding = "utf-16" if os.name == "nt" else "utf-8" # docker on windows nonsense + with jax_cpfem_stdout.open("rt", encoding=encoding) as fh: + contents = fh.read() + + fp = [] + for fp_res in re.findall(_PAT_FP, contents): + fp.append([[float(i) for i in fp_row.split()] for fp_row in fp_res]) + fp_arr = np.asarray(fp) + + stress_xx = np.array( + [float(i) for i in re.search(_PAT_STRESS_XX, contents).groups()[0].split()] + ) + stress_yy = np.array( + [float(i) for i in re.search(_PAT_STRESS_YY, contents).groups()[0].split()] + ) + stress_zz = np.array( + [float(i) for i in re.search(_PAT_STRESS_ZZ, contents).groups()[0].split()] + ) + stress_von_mises = np.array( + [float(i) for i in re.search(_PAT_STRESS_VM, contents).groups()[0].split()] + ) + + return { + "Fp": fp_arr, + "stress_xx": stress_xx, + "stress_yy": stress_yy, + "stress_zz": stress_zz, + "stress_von_mises": stress_von_mises, + } diff --git a/matflow/data/scripts/jax_cpfem/write_mesh.py b/matflow/data/scripts/jax_cpfem/write_mesh.py new file mode 100644 index 00000000..9a3497d6 --- /dev/null +++ b/matflow/data/scripts/jax_cpfem/write_mesh.py @@ -0,0 +1,14 @@ +import numpy as np + + +def write_mesh(path, volume_element): + np.savez( + path, + num_cells=volume_element["grid_size"][:], # i.e. Nx, Ny, Nz + domain=np.asarray(volume_element["size"]), # i.e. domain_x, domain_y, domain_z + cell_grain_indices=volume_element["element_material_idx"][:].reshape( + -1 + ), # TODO: check order! + quaternions=volume_element["orientations"]["quaternions"][:], + grain_orientation_indices=volume_element["constituent_orientation_idx"][:], + ) diff --git a/matflow/data/scripts/jax_cpfem/write_model_py_script.py b/matflow/data/scripts/jax_cpfem/write_model_py_script.py new file mode 100644 index 00000000..0857dd78 --- /dev/null +++ b/matflow/data/scripts/jax_cpfem/write_model_py_script.py @@ -0,0 +1,429 @@ +from textwrap import dedent + + +def write_model_py_script(path, material_parameters, numerics): + + TEMPLATE = dedent( + """\ + import numpy as onp + import jax + import jax.numpy as np + import jax.flatten_util + import os + import sys + from functools import partial + + from jax_fem.problem import Problem + from jax import config + + config.update("jax_enable_x64", True) + + onp.set_printoptions(threshold=sys.maxsize, linewidth=1000, suppress=True, precision=10) + + + crt_dir = os.path.dirname(__file__) + + + def rotate_tensor_rank_4(R, T): + R0 = R[:, :, None, None, None, None, None, None] + R1 = R[None, None, :, :, None, None, None, None] + R2 = R[None, None, None, None, :, :, None, None] + R3 = R[None, None, None, None, None, None, :, :] + return np.sum( + R0 * R1 * R2 * R3 * T[None, :, None, :, None, :, None, :], axis=(1, 3, 5, 7) + ) + + + def rotate_tensor_rank_2(R, T): + R0 = R[:, :, None, None] + R1 = R[None, None, :, :] + return np.sum(R0 * R1 * T[None, :, None, :], axis=(1, 3)) + + + rotate_tensor_rank_2_vmap = jax.vmap(rotate_tensor_rank_2, in_axes=(None, 0)) + + + def get_rot_mat(q): + ''' + Transformation from quaternion to the corresponding rotation matrix. + Reference: https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation + ''' + ## Hu: return rotation matrix -- (3,3) + return np.array( + [ + [ + q[0] * q[0] + q[1] * q[1] - q[2] * q[2] - q[3] * q[3], + 2 * q[1] * q[2] - 2 * q[0] * q[3], + 2 * q[1] * q[3] + 2 * q[0] * q[2], + ], + [ + 2 * q[1] * q[2] + 2 * q[0] * q[3], + q[0] * q[0] - q[1] * q[1] + q[2] * q[2] - q[3] * q[3], + 2 * q[2] * q[3] - 2 * q[0] * q[1], + ], + [ + 2 * q[1] * q[3] - 2 * q[0] * q[2], + 2 * q[2] * q[3] + 2 * q[0] * q[1], + q[0] * q[0] - q[1] * q[1] - q[2] * q[2] + q[3] * q[3], + ], + ] + ) + + + get_rot_mat_vmap = jax.vmap(get_rot_mat) + + + class CrystalPlasticity(Problem): + def custom_init(self, quat, cell_ori_inds): + r = {r} + self.gss_initial = {gss_initial} + + input_slip_sys = onp.loadtxt("input_slip_sys.txt") + num_slip_sys = len(input_slip_sys) + + slip_directions = input_slip_sys[:, self.dim :] + slip_directions = ( + slip_directions / onp.linalg.norm(slip_directions, axis=1)[:, None] + ) + slip_normals = input_slip_sys[:, : self.dim] + slip_normals = slip_normals / onp.linalg.norm(slip_normals, axis=1)[:, None] + + self.Schmid_tensors = jax.vmap(np.outer)(slip_directions, slip_normals) + + self.q = r * onp.ones((num_slip_sys, num_slip_sys)) + + num_directions_per_normal = 3 + for i in range(num_slip_sys): + for j in range(num_directions_per_normal): + self.q[ + i, i // num_directions_per_normal * num_directions_per_normal + j + ] = 1.0 + + rot_mats = onp.array(get_rot_mat_vmap(quat)[cell_ori_inds]) + + ### Note: for CPFEM, self.num_vars=1, which means fes only have 1 Finite Element object + ### Multi-physics CPFEM is under study + ## Hu: Fp - plastic deformation gradient + Fp_inv_gp = onp.repeat( + onp.repeat( + onp.eye(self.dim)[None, None, :, :], len(self.fes[0].cells), axis=0 + ), + self.fes[0].num_quads, + axis=1, + ) + ## Hu: slip resistance + slip_resistance_gp = self.gss_initial * onp.ones( + (len(self.fes[0].cells), self.fes[0].num_quads, num_slip_sys) + ) + ## Hu: slip rate + slip_gp = onp.zeros_like(slip_resistance_gp) + ## Hu: rotation matrix + rot_mats_gp = onp.repeat(rot_mats[:, None, :, :], self.fes[0].num_quads, axis=1) + ## Hu: elastic modulus + self.C = onp.zeros((self.dim, self.dim, self.dim, self.dim)) + + ## Hu: unit: MPa + C11 = {C11} + C12 = {C12} + C44 = {C44} + + self.C[0, 0, 0, 0] = C11 + self.C[1, 1, 1, 1] = C11 + self.C[2, 2, 2, 2] = C11 + + self.C[0, 0, 1, 1] = C12 + self.C[1, 1, 0, 0] = C12 + + self.C[0, 0, 2, 2] = C12 + self.C[2, 2, 0, 0] = C12 + + self.C[1, 1, 2, 2] = C12 + self.C[2, 2, 1, 1] = C12 + + self.C[1, 2, 1, 2] = C44 + self.C[1, 2, 2, 1] = C44 + self.C[2, 1, 1, 2] = C44 + self.C[2, 1, 2, 1] = C44 + + self.C[2, 0, 2, 0] = C44 + self.C[2, 0, 0, 2] = C44 + self.C[0, 2, 2, 0] = C44 + self.C[0, 2, 0, 2] = C44 + + self.C[0, 1, 0, 1] = C44 + self.C[0, 1, 1, 0] = C44 + self.C[1, 0, 0, 1] = C44 + self.C[1, 0, 1, 0] = C44 + + ## Hu: internal variables + self.internal_vars = [Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp] + + def get_tensor_map(self): + tensor_map, _ = self.get_maps() + return tensor_map + + def get_maps(self): + ## Hu: initial hardening, unit: MPa + h = {h} + ## Hu: saturation strength, unit: MPa + t_sat = {t_sat} + ## Hu: 'a' in Kalidini model + gss_a = {gss_a} + ## Hu: reference strain rate + ao = {ao} + ## Hu: rate sensitivity exponent + xm = 1.0 / {xm_d} + + def get_partial_tensor_map(Fp_inv_old, slip_resistance_old, slip_old, rot_mat): + _, unflatten_fn = jax.flatten_util.ravel_pytree(Fp_inv_old) + _, unflatten_fn_params = jax.flatten_util.ravel_pytree( + [Fp_inv_old, Fp_inv_old, slip_resistance_old, slip_old, rot_mat] + ) + + def first_PK_stress(u_grad): + x, _ = jax.flatten_util.ravel_pytree( + [u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat] + ) + y = newton_solver(x) + S = unflatten_fn(y) + _, _, _, Fe, F = helper( + u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat, S + ) + sigma = 1.0 / np.linalg.det(Fe) * Fe @ S @ Fe.T + P = np.linalg.det(F) * sigma @ np.linalg.inv(F).T + return P + + def update_int_vars(u_grad): + x, _ = jax.flatten_util.ravel_pytree( + [u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat] + ) + y = newton_solver(x) + S = unflatten_fn(y) + Fp_inv_new, slip_resistance_new, slip_new, Fe, F = helper( + u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat, S + ) + return Fp_inv_new, slip_resistance_new, slip_new, rot_mat + + ## Hu: S-based formulation + def helper(u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat, S): + tau = np.sum( + S[None, :, :] + * rotate_tensor_rank_2_vmap(rot_mat, self.Schmid_tensors), + axis=(1, 2), + ) + gamma_inc = ( + ao + * self.dt + * np.absolute(tau / slip_resistance_old) ** (1.0 / xm) + * np.sign(tau) + ) + + tmp = ( + h + * np.absolute(gamma_inc) + * np.absolute(1 - slip_resistance_old / t_sat) ** gss_a + * np.sign(1 - slip_resistance_old / t_sat) + ) + g_inc = (self.q @ tmp[:, None]).reshape(-1) + + # tmp = h*np.absolute(gamma_inc) / np.cosh(h*np.sum(slip_old)/(t_sat - self.gss_initial))**2 + # g_inc = (self.q @ tmp[:, None]).reshape(-1) + + slip_resistance_new = slip_resistance_old + g_inc + slip_new = slip_old + gamma_inc + F = u_grad + np.eye(self.dim) + L_plastic_inc = np.sum( + gamma_inc[:, None, None] + * rotate_tensor_rank_2_vmap(rot_mat, self.Schmid_tensors), + axis=0, + ) + Fp_inv_new = Fp_inv_old @ (np.eye(self.dim) - L_plastic_inc) + Fe = F @ Fp_inv_new + return Fp_inv_new, slip_resistance_new, slip_new, Fe, F + + ## Hu: Calculate the residual function + def implicit_residual(x, y): + u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat = ( + unflatten_fn_params(x) + ) + S = unflatten_fn(y) + _, _, _, Fe, _ = helper( + u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat, S + ) + S_ = np.sum( + rotate_tensor_rank_4(rot_mat, self.C) + * 1.0 + / 2.0 + * (Fe.T @ Fe - np.eye(self.dim))[None, None, :, :], + axis=(2, 3), + ) + res, _ = jax.flatten_util.ravel_pytree(S - S_) + return res + + ## Hu: inner Newton's method + @jax.custom_jvp + def newton_solver(x): + # Critical change: The following line causes JAX (version 0.4.13) tracer error + # y0 = np.zeros_like(Fp_inv_old.reshape(-1)) + y0 = np.zeros(self.dim * self.dim) + + step = 0 + res_vec = implicit_residual(x, y0) + tol = 1e-8 + + def cond_fun(state): + step, res_vec, y = state + ## Hu: In MOOSE: while (rnorm > _rtol * rnorm0 && rnorm > _abs_tol && iteration < _maxiter) + return np.linalg.norm(res_vec) > tol + + def body_fun(state): + # Line search with decaying relaxation parameter (the "cut half" method). + # This is necessary since vanilla Newton's method may sometimes not converge. + # MOOSE has an implementation in C++, see the following link + # https://github.com/idaholab/moose/blob/next/modules/tensor_mechanics/src/materials/ + # crystal_plasticity/ComputeMultipleCrystalPlasticityStress.C#L634 + step, res_vec, y = state + ## Hu: Input: S, Output: res + f_partial = lambda y: implicit_residual(x, y) + jac = jax.jacfwd(f_partial)(y) + y_inc = np.linalg.solve(jac, -res_vec) + + relax_param_ini = 1.0 + sub_step_ini = 0 + max_sub_step = {newton_max_sub_step} + + def sub_cond_fun(state): + _, crt_res_vec, sub_step = state + return np.logical_and( + np.linalg.norm(crt_res_vec) >= np.linalg.norm(res_vec), + sub_step < max_sub_step, + ) + + def sub_body_fun(state): + relax_param, res_vec, sub_step = state + res_vec = f_partial(y + relax_param * y_inc) + return 0.5 * relax_param, res_vec, sub_step + 1 + + relax_param_f, res_vec_f, _ = jax.lax.while_loop( + sub_cond_fun, + sub_body_fun, + (relax_param_ini, res_vec, sub_step_ini), + ) + step_update = step + 1 + + return step_update, res_vec_f, y + 2.0 * relax_param_f * y_inc + + step_f, res_vec_f, y_f = jax.lax.while_loop( + cond_fun, body_fun, (step, res_vec, y0) + ) + + return y_f + + @newton_solver.defjvp + def f_jvp(primals, tangents): + (x,) = primals + (v,) = tangents + y = newton_solver(x) + jac_x = jax.jacfwd(implicit_residual, argnums=0)(x, y) + jac_y = jax.jacfwd(implicit_residual, argnums=1)(x, y) + jvp_result = np.linalg.solve(jac_y, -(jac_x @ v[:, None]).reshape(-1)) + return y, jvp_result + + return first_PK_stress, update_int_vars + + def tensor_map(u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat): + first_PK_stress, _ = get_partial_tensor_map( + Fp_inv_old, slip_resistance_old, slip_old, rot_mat + ) + return first_PK_stress(u_grad) + + def update_int_vars_map( + u_grad, Fp_inv_old, slip_resistance_old, slip_old, rot_mat + ): + _, update_int_vars = get_partial_tensor_map( + Fp_inv_old, slip_resistance_old, slip_old, rot_mat + ) + return update_int_vars(u_grad) + + return tensor_map, update_int_vars_map + + def update_int_vars_gp(self, sol, params): + _, update_int_vars_map = self.get_maps() + vmap_update_int_vars_map = jax.jit(jax.vmap(jax.vmap(update_int_vars_map))) + # (num_cells, 1, num_nodes, vec, 1) * (num_cells, num_quads, num_nodes, 1, dim) -> (num_cells, num_quads, num_nodes, vec, dim) + u_grads = ( + np.take(sol, self.fes[0].cells, axis=0)[:, None, :, :, None] + * self.fes[0].shape_grads[:, :, :, None, :] + ) + u_grads = np.sum(u_grads, axis=2) # (num_cells, num_quads, vec, dim) + Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp = vmap_update_int_vars_map( + u_grads, *params + ) + # TODO + return [Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp] + + def set_params(self, params): + self.internal_vars = params + + def inspect_interval_vars(self, params): + '''For post-processing only''' + Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp = params + F_p = np.linalg.inv(Fp_inv_gp[0, 0]) + print(f"Fp = \\n{{F_p}}") + slip_resistance_0 = slip_resistance_gp[0, 0, 0] + print( + f"slip_resistance index 0 = {{slip_resistance_0}}, max slip_resistance = {{np.max(slip_resistance_gp)}}" + ) + return F_p[2, 2], slip_resistance_0, slip_gp[0, 0, 0] + + def compute_avg_stress(self, sol, params): + '''For post-processing only''' + # (num_cells, 1, num_nodes, vec, 1) * (num_cells, num_quads, num_nodes, 1, dim) -> (num_cells, num_quads, num_nodes, vec, dim) + u_grads = ( + np.take(sol, self.fes[0].cells, axis=0)[:, None, :, :, None] + * self.fes[0].shape_grads[:, :, :, None, :] + ) + u_grads = np.sum(u_grads, axis=2) # (num_cells, num_quads, vec, dim) + + partial_tensor_map, _ = self.get_maps() + vmap_partial_tensor_map = jax.jit(jax.vmap(jax.vmap(partial_tensor_map))) + P = vmap_partial_tensor_map(u_grads, *params) + + def P_to_sigma(P, F): + return 1.0 / np.linalg.det(F) * P @ F.T + + vvmap_P_to_sigma = jax.vmap(jax.vmap(P_to_sigma)) + F = u_grads + np.eye(self.dim)[None, None, :, :] + sigma = vvmap_P_to_sigma(P, F) + + sigma_cell_data = ( + np.sum(sigma * self.fes[0].JxW[:, :, None, None], 1) + / np.sum(self.fes[0].JxW, axis=1)[:, None, None] + ) + + # num_cells*num_quads, vec, dim) * (num_cells*num_quads, 1, 1) + avg_P = np.sum( + P.reshape(-1, self.fes[0].vec, self.dim) + * self.fes[0].JxW.reshape(-1)[:, None, None], + 0, + ) / np.sum(self.fes[0].JxW) + return sigma_cell_data + """ + ) + + with path.open("wt") as fh: + fh.write( + TEMPLATE.format( + C11=material_parameters["C11"], + C12=material_parameters["C12"], + C44=material_parameters["C44"], + r=material_parameters["r"], + gss_initial=material_parameters["gss_initial"], + h=material_parameters["h"], + t_sat=material_parameters["t_sat"], + gss_a=material_parameters["gss_a"], + ao=material_parameters["ao"], + xm_d=material_parameters["xm_d"], + newton_max_sub_step=numerics["newton_max_sub_step"], + ) + ) diff --git a/matflow/data/scripts/jax_cpfem/write_problem_py_script.py b/matflow/data/scripts/jax_cpfem/write_problem_py_script.py new file mode 100644 index 00000000..7fe6c917 --- /dev/null +++ b/matflow/data/scripts/jax_cpfem/write_problem_py_script.py @@ -0,0 +1,326 @@ +from __future__ import annotations +from textwrap import dedent, indent +from collections.abc import Sequence + +import numpy as np + +from matflow.param_classes.boundary_conditions import BoundaryCondition + +_DIR_LOOKUP = { + "x": "Lx", + "y": "Ly", + "z": "Lz", +} + + +def write_problem_py_script(path, load_case, solver_options): + + TEMPLATE = dedent( + """\ + import jax + import jax.numpy as np + import numpy as onp + import os + import time + + from jax_fem.solver import solver + from jax_fem.generate_mesh import box_mesh, Mesh + from jax_fem.utils import save_sol + + + from model import CrystalPlasticity + + os.environ["CUDA_VISIBLE_DEVICES"] = "2" + + case_name = "{case_name}" + + data_dir = os.path.join(os.path.dirname(__file__), "data") + vtk_dir = os.path.join(data_dir, f"vtk/{{case_name}}") + + + def load_mesh_file(): + mesh_data = onp.load("mesh.npz") + meshio_obj = box_mesh(*mesh_data["num_cells"], *mesh_data["domain"]) + mesh_obj = Mesh(meshio_obj.points, meshio_obj.cells_dict["hexahedron"], ele_type="HEX8") + return {{ + "mesh_obj": mesh_obj, + "cell_grain_indices": mesh_data["cell_grain_indices"], + "quaternions": mesh_data["quaternions"], + "grain_orientation_indices": mesh_data["grain_orientation_indices"], + }} + + + def problem(): + print(jax.lib.xla_bridge.get_backend().platform) + + ele_type = "HEX8" + + mesh_data = load_mesh_file() + mesh = mesh_data["mesh_obj"] + cell_grain_inds = mesh_data["cell_grain_indices"] + grain_oris_inds = mesh_data["grain_orientation_indices"] + cell_ori_inds = grain_oris_inds[cell_grain_inds] + quat = mesh_data["quaternions"] + + print(f"{{quat=!r}}") + print("cell_grain_inds", cell_grain_inds) + print("grain_oris_inds", grain_oris_inds) + print("cell_ori_inds", cell_ori_inds) + print("No. of total mesh points:", mesh.points.shape) + + Lx = np.max(mesh.points[:, 0]) + Ly = np.max(mesh.points[:, 1]) + Lz = np.max(mesh.points[:, 2]) + print(f"Domain size: {{Lx}}, {{Ly}}, {{Lz}}") + + {dirichlet_BCs} + + ## Hu: Define CPFEM problem on top of JAX-FEM + ## Xue, Tianju, et al. Computer Physics Communications 291 (2023): 108802. + problem = CrystalPlasticity( + mesh, + vec=3, + dim=3, + ele_type=ele_type, + dirichlet_bc_info=dirichlet_bc_info, + additional_info=(quat, cell_ori_inds), + ) + + sol_list = [np.zeros((problem.fes[0].num_total_nodes, problem.fes[0].vec))] + ## Hu: self.internal_vars = [Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp] + params = problem.internal_vars + + results_to_save = [] + stress_plot = np.array([]) + stress_xx_plot = np.array([]) + stress_yy_plot = np.array([]) + stress_zz_plot = np.array([]) + von_mises_stress_plot = np.array([]) + + for i in range(len(ts) - 1): + problem.dt = ts[i + 1] - ts[i] + print( + f"\\nStep {{i + 1}} in {{len(ts) - 1}}, disp = {{displacements[i + 1]}}, dt = {{problem.dt}}" + ) + + ## Hu: Reset Dirichlet boundary conditions. + ## Hu: Useful when a time-dependent problem is solved, and at each iteration the boundary condition needs to be updated. + dirichlet_bc_info[-1][{dirichlet_ut_idx}] = constant_value(displacements[i + 1]) + problem.fes[0].update_Dirichlet_boundary_conditions(dirichlet_bc_info) + + ## Hu: Set up internal variables of previous step for inner Newton's method + ## self.internal_vars = [Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp] + problem.set_params(params) + + ## Hu: JAX-FEM's solver for outer Newton's method + ## solver(problem, solver_options={{}}) + ## Examples: + ## (1) solver_options = {{'jax_solver': {{}}}} + ## (2) solver_options = {{'umfpack_solver': {{}}}} + ## (3) solver_options = {{'petsc_solver': {{'ksp_type': 'bcgsl', 'pc_type': 'jacobi'}}, 'initial_guess': some_guess}} + sol_list = solver( + problem, solver_options={{"{solver_name}": {solver_options}, "initial_guess": sol_list}} + ) + + ## Hu: Post-processing for aacroscopic Cauchy stress of each cell + print(f"Computing stress...") + sigma_cell_data = problem.compute_avg_stress(sol_list[0], params)[:, :, :] + sigma_cell_xx = sigma_cell_data[:, 0, 0] + sigma_cell_yy = sigma_cell_data[:, 1, 1] + sigma_cell_zz = sigma_cell_data[:, 2, 2] + sigma_cell_xy = sigma_cell_data[:, 0, 1] + sigma_cell_xz = sigma_cell_data[:, 0, 2] + sigma_cell_yz = sigma_cell_data[:, 1, 2] + sigma_cell_von_mises_stress = ( + 0.5 + * ( + (sigma_cell_xx - sigma_cell_yy) ** 2.0 + + (sigma_cell_yy - sigma_cell_zz) ** 2.0 + + (sigma_cell_zz - sigma_cell_xx) ** 2.0 + ) + + +3.0 * (sigma_cell_xy**2.0 + sigma_cell_yz**2.0 + sigma_cell_xz**2.0) + ) ** 0.5 + + stress_xx_plot = np.append(stress_xx_plot, np.mean(sigma_cell_xx)) + stress_yy_plot = np.append(stress_yy_plot, np.mean(sigma_cell_yy)) + stress_zz_plot = np.append(stress_zz_plot, np.mean(sigma_cell_zz)) + von_mises_stress_plot = np.append( + von_mises_stress_plot, np.mean(sigma_cell_von_mises_stress) + ) + print( + f"Average Cauchy stress: stress_xx = {{stress_xx_plot[-1]}}, stress_yy = {{stress_yy_plot[-1]}}, stress_zz = {{stress_zz_plot[-1]}}, \\ + vM_stress = {{von_mises_stress_plot[-1]}}, max stress = {{np.max(sigma_cell_data)}}" + ) + + ## Hu: Update internal variables + ## self.internal_vars = [Fp_inv_gp, slip_resistance_gp, slip_gp, rot_mats_gp] + print(f"Updating int vars...") + params = problem.update_int_vars_gp(sol_list[0], params) + F_p_zz, slip_resistance_0, slip_0 = problem.inspect_interval_vars(params) + + ## Hu: Post-processing for visualization + vtk_path = os.path.join(vtk_dir, f"u_{{i:03d}}.vtu") + save_sol( + problem.fes[0], + sol_list[0], + vtk_path, + cell_infos=[ + ("cell_ori_inds", cell_ori_inds), + ("sigma_xx", sigma_cell_xx), + ("sigma_yy", sigma_cell_yy), + ("sigma_zz", sigma_cell_zz), + ("von_Mises_stress", sigma_cell_von_mises_stress), + ], + ) + + print("*************") + print("grain_oris_inds:\\n", grain_oris_inds) + print("stress_xx:\\n", onp.array(stress_xx_plot, order="F", dtype=onp.float64)) + print("stress_yy:\\n", onp.array(stress_yy_plot, order="F", dtype=onp.float64)) + print("stress_zz:\\n", onp.array(stress_zz_plot, order="F", dtype=onp.float64)) + print( + "von_mises_stress:\\n", + onp.array(von_mises_stress_plot, order="F", dtype=onp.float64), + ) + print("*************") + + + if __name__ == "__main__": + start_time = time.time() + problem() + end_time = time.time() + run_time = end_time - start_time + print("Simulation time:", run_time) + """ + ) + + dirichlet_BCs, dirichlet_ut_idx = create_JAX_CPFEM_boundary_conditions_code( + load_case=load_case, domain_size=["Lx", "Ly", "Lz"] + ) + + INDENT = " " + solver_name = solver_options.pop("name") + with path.open("wt") as fh: + fh.write( + TEMPLATE.format( + case_name="polycrystal", + dirichlet_BCs=indent(dirichlet_BCs, INDENT), + dirichlet_ut_idx=dirichlet_ut_idx, + solver_name=solver_name, + solver_options=solver_options, + ) + ) + + +def __apply_for_values( + dir_BC, func_name, loc_funcs, vec_comps, value_funcs +) -> tuple[str, int | None]: + out = "" + dirichlet_update_idx = None + for field_dir, value in dir_BC.value.items(): + loc_funcs.append(func_name) + vec_comps.append(BoundaryCondition._DIRS.index(field_dir)) + if value == 0: + # a function that takes a point and returns zero: + value_str = "zero_value" + if value_str not in value_funcs: + out += dedent( + """\ + def zero_value(point): + return 0 + + """ + ) + elif value.lower() == "u(t)": + # a function that takes a point and returns the required + # displacement at time t: + dirichlet_update_idx = len(value_funcs) + value_str = "constant_value(displacements[0])" + if value_str not in value_funcs: + out += dedent( + """\ + def constant_value(value): + def val_fn(point): + return value + return val_fn + + """ + ) + + value_funcs.append(value_str) + return out, dirichlet_update_idx + + +def create_JAX_CPFEM_boundary_conditions_code( + load_case, domain_size: Sequence[float | str] +) -> tuple[str, int]: + """ + Create a string containing Python code that can be used to define this + load case (represented with Dirichlet boundary conditions) when using JAX-CPFEM. + """ + + dir_BCs = load_case.to_dirichlet_BCs() + step = load_case.steps[0] + + domain_size_arr = np.asarray(domain_size) + func_lst = set() + out = "" + dirichlet_update_idx = None + loc_funcs = [] + vec_comps = [] + value_funcs = [] + for dir_BC in dir_BCs: + corners = dir_BC.corners + faces = dir_BC.faces + num_regions = len(corners) if corners else len(faces) + for corner in corners or (): + func_name = BoundaryCondition._JAX_CPFEM_BC_FUNC_CORNER_MAP[tuple(corner)] + if func_name not in func_lst: + func_lst.update(func_name) + # apply domain size: + corner_arr = np.asarray(corner) + is_max = corner_arr == 1 + corner_arr[is_max] = domain_size_arr[is_max] + out += BoundaryCondition.get_corner_location_func_str(corner_arr) + out_i, update_idx = __apply_for_values( + dir_BC, func_name, loc_funcs, vec_comps, value_funcs + ) + out += out_i + + if update_idx is not None: + dirichlet_update_idx = update_idx + + for face in faces or (): + func_name = BoundaryCondition._JAX_CPFEM_BC_FUNC_FACE_MAP[face] + if func_name not in func_lst: + func_lst.update(func_name) + out += BoundaryCondition.get_face_location_func_str(face, domain_size) + out_i, update_idx = __apply_for_values( + dir_BC, func_name, loc_funcs, vec_comps, value_funcs + ) + out += out_i + if update_idx is not None: + dirichlet_update_idx = update_idx + + out += dedent( + """\ + displacements = np.linspace(0, {strain}*{domain_i}, {num_increments} + 1) + ts = np.linspace(0, {total_time}, {num_increments} + 1) + + dirichlet_bc_info = [ + {location_functions_str}, + {vector_components_str}, + {value_functions_str}, + ] + """ + ).format( + strain=step.strain, + domain_i=_DIR_LOOKUP[step.direction], + num_increments=step.num_increments, + total_time=step.total_time, + location_functions_str="[" + ", ".join(loc_funcs) + "]", + vector_components_str="[" + ", ".join(str(i) for i in vec_comps) + "]", + value_functions_str="[" + ", ".join(value_funcs) + "]", + ) + + return out, dirichlet_update_idx diff --git a/matflow/data/scripts/jax_cpfem/write_slip_systems.py b/matflow/data/scripts/jax_cpfem/write_slip_systems.py new file mode 100644 index 00000000..7065cada --- /dev/null +++ b/matflow/data/scripts/jax_cpfem/write_slip_systems.py @@ -0,0 +1,26 @@ +from textwrap import dedent + + +def write_slip_systems(path, slip_systems): + # slip plane normals, followed by slip directions + # FCC (?) + + # TEMPLATE = dedent( + # """\ + # 1 1 -1 0 1 1 + # 1 1 -1 1 0 1 + # 1 1 -1 1 -1 0 + # 1 -1 -1 0 1 -1 + # 1 -1 -1 1 0 1 + # 1 -1 -1 1 1 0 + # 1 -1 1 0 1 1 + # 1 -1 1 1 0 -1 + # 1 -1 1 1 1 0 + # 1 1 1 0 1 -1 + # 1 1 1 1 0 -1 + # 1 1 1 1 -1 0 + # """ + # ) + with path.open("wt") as fh: + for slip_plane, slip_dir in slip_systems: + fh.write(" ".join(str(i) for i in [*slip_plane, *slip_dir]) + "\n") diff --git a/matflow/data/template_components/command_files.yaml b/matflow/data/template_components/command_files.yaml index 9c181f30..0116eb4a 100644 --- a/matflow/data/template_components/command_files.yaml +++ b/matflow/data/template_components/command_files.yaml @@ -69,3 +69,31 @@ name: name: pipeline.xdmf doc: DREAM.3D model data and metadata. + +- label: jax_cpfem_mesh_data + name: + name: mesh.npz + doc: > + Information that we can use to generate a JAX-FEM hexahedral ("box") `Mesh` + object, and associate each grain with quaternion orientations. +- label: jax_cpfem_model_script + name: + name: model.py + doc: > + JAX-CPFEM model definition script, including material parameter definitions. +- label: jax_cpfem_problem_script + name: + name: problem.py + doc: JAX-CPFEM problem definition script, including boundary condition definitions. +- label: jax_cpfem_slip_systems + name: + name: input_slip_sys.txt + doc: Slip systems to model. +- label: jax_cpfem_stdout + name: + name: stdout.log + doc: Standard output stream from a JAX-CPFEM simulation. +- label: jax_cpfem_stderr + name: + name: stderr.log + doc: Standard error stream from a JAX-CPFEM simulation. diff --git a/matflow/data/template_components/task_schemas.yaml b/matflow/data/template_components/task_schemas.yaml index 8d60cf52..ea3f8191 100644 --- a/matflow/data/template_components/task_schemas.yaml +++ b/matflow/data/template_components/task_schemas.yaml @@ -792,6 +792,49 @@ script: <> inputs: [damask_post_processing, VE_response_data, damask_viz] +- objective: simulate_VE_loading + doc: Simulate applying a load case to a volume element using JAX-CPFEM. + implementation: JAX_CPFEM + inputs: + - parameter: volume_element + - parameter: material_parameters + - parameter: slip_systems + - parameter: load_case + - parameter: solver_options + - parameter: numerics + outputs: + - parameter: CP_outputs # temp naming + actions: + - requires_dir: true + environments: + - scope: + type: main + environment: jax_cpfem + - scope: + type: processing + environment: python_env + input_file_generators: + - input_file: jax_cpfem_mesh_data + from_inputs: [volume_element] + script: <> + - input_file: jax_cpfem_model_script + from_inputs: [material_parameters, numerics] + script: <> + - input_file: jax_cpfem_problem_script + from_inputs: [load_case, solver_options] + script: <> + - input_file: jax_cpfem_slip_systems + from_inputs: [slip_systems] + script: <> + commands: + - command: <> problem + stdout: stdout.log + stderr: stderr.log + output_file_parsers: + CP_outputs: + from_files: [jax_cpfem_stdout] + script: <> + - objective: read_tensile_test doc: Read tensile test data from CSV. method: from_CSV diff --git a/matflow/data/workflows/simulate_JAX_CPFEM_DAMASK.yaml b/matflow/data/workflows/simulate_JAX_CPFEM_DAMASK.yaml new file mode 100644 index 00000000..8e14a0fb --- /dev/null +++ b/matflow/data/workflows/simulate_JAX_CPFEM_DAMASK.yaml @@ -0,0 +1,117 @@ +tasks: + - schema: generate_volume_element_from_voronoi + inputs: + periodic: false + VE_grid_size: [10, 10, 10] + orientations: + data: + - [1, 0, 0, 0] + unit_cell_alignment: { x: a, y: b, z: c } + representation: + type: quaternion + quat_order: scalar_vector + microstructure_seeds::from_random: + num_seeds: 1 + box_size: [0.1, 0.1, 0.1] + phase_label: copper + + - schema: visualise_VE_VTK + + - schema: define_load_case + inputs: + load_case::uniaxial: + total_time: 0.1 + num_increments: 10 + direction: z + target_strain: 0.01 + dump_frequency: 1 + + - schema: simulate_VE_loading_damask + inputs: + homogenization: + SX: + mechanical: { type: "pass" } + N_constituents: 1 + damask_phases: + copper: # TODO: # change to Cu material parameters! + lattice: cF + mechanical: + output: [F, P, F_e, F_p, L_p, O] + elastic: + type: Hooke + C_11: 106750000000 + C_12: 60410000000 + C_44: 28340000000 + plastic: + type: phenopowerlaw + N_sl: [12] + a_sl: 2.25 + atol_xi: 1 + dot_gamma_0_sl: 0.001 + h_0_sl-sl: 75.0e+6 + h_sl-sl: [1, 1, 1.4, 1.4, 1.4, 1.4, 1.4] + n_sl: 20 + output: [xi_sl] + xi_0_sl: [31.0e+6] + xi_inf_sl: [63.0e+6] + damask_post_processing: + - name: add_stress_Cauchy + args: { P: P, F: F } + opts: { add_Mises: true } + - name: add_strain + args: { F: F, t: V, m: 0 } + opts: { add_Mises: true } + - name: add_strain + args: { F: F_p, t: V, m: 0 } + opts: { add_Mises: true } + - name: add_IPF_color + args: { l: [0, 0, 1] } + VE_response_data: + phase_data: + - field_name: sigma_vM + phase_name: copper + out_name: vol_avg_equivalent_stress + transforms: [{ mean_along_axes: 1 }] + - field_name: epsilon_V^0(F)_vM + phase_name: copper + out_name: vol_avg_equivalent_strain + transforms: [{ mean_along_axes: 1 }] + - field_name: epsilon_V^0(F_p)_vM + phase_name: copper + out_name: vol_avg_equivalent_plastic_strain + transforms: [{ mean_along_axes: 1 }] + field_data: + - field_name: phase + - field_name: O + grain_data: + - field_name: O + increments: [{ values: [0, -1] }] + + - schema: simulate_VE_loading_JAX_CPFEM + inputs: + material_parameters: + C11: 1.684e5 # MPa + C12: 1.214e5 # MPa + C44: 0.754e5 # MPa + r: 1 # latent hardening + gss_initial: 60.8 # initial flow stress, in MPa + h: 541.5 # initial hardening, in MPa + t_sat: 109.8 # saturation strength, in MPa + gss_a: 2.5 # 'a' in Kalidini model + ao: 0.001 # reference strain rate + xm_d: 10 # rate sensitivity exponent (denominator) + slip_systems: # plane and direction + - [[1, 1, -1], [0, 1, 1]] + - [[1, 1, -1], [1, 0, 1]] + - [[1, 1, -1], [1, -1, 0]] + - [[1, -1, -1], [0, 1, -1]] + - [[1, -1, -1], [1, 0, 1]] + - [[1, -1, -1], [1, 1, 0]] + - [[1, -1, 1], [0, 1, 1]] + - [[1, -1, 1], [1, 0, -1]] + - [[1, -1, 1], [1, 1, 0]] + - [[1, 1, 1], [0, 1, -1]] + - [[1, 1, 1], [1, 0, -1]] + - [[1, 1, 1], [1, -1, 0]] + numerics: + newton_max_sub_step: 5 diff --git a/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml b/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml new file mode 100644 index 00000000..4c9389a9 --- /dev/null +++ b/matflow/data/workflows/simulate_JAX_CPFEM_copper.yaml @@ -0,0 +1,60 @@ +tasks: + - schema: generate_volume_element_from_voronoi + inputs: + periodic: false + VE_grid_size: [10, 10, 10] + orientations: + data: + - [1, 0, 0, 0] + unit_cell_alignment: { x: a, y: b, z: c } + representation: + type: quaternion + quat_order: scalar_vector + microstructure_seeds::from_random: + num_seeds: 1 + box_size: [0.1, 0.1, 0.1] + phase_label: copper + + - schema: visualise_VE_VTK + + - schema: define_load_case + inputs: + load_case::one_dimensional: + num_increments: 10 + total_time: 0.1 + target_strain: 0.01 + direction: x + + - schema: simulate_VE_loading_JAX_CPFEM + inputs: + material_parameters: + C11: 1.684e5 # MPa + C12: 1.214e5 # MPa + C44: 0.754e5 # MPa + r: 1 # latent hardening + gss_initial: 60.8 # initial flow stress, in MPa + h: 541.5 # initial hardening, in MPa + t_sat: 109.8 # saturation strength, in MPa + gss_a: 2.5 # 'a' in Kalidini model + ao: 0.001 # reference strain rate + xm_d: 10 # rate sensitivity exponent (denominator) + solver_options: + name: jax_solver + # name: petsc_solver + # ksp_type: bcgsl + # pc_type: jacobi + slip_systems: # plane and direction + - [[1, 1, -1], [0, 1, 1]] + - [[1, 1, -1], [1, 0, 1]] + - [[1, 1, -1], [1, -1, 0]] + - [[1, -1, -1], [0, 1, -1]] + - [[1, -1, -1], [1, 0, 1]] + - [[1, -1, -1], [1, 1, 0]] + - [[1, -1, 1], [0, 1, 1]] + - [[1, -1, 1], [1, 0, -1]] + - [[1, -1, 1], [1, 1, 0]] + - [[1, 1, 1], [0, 1, -1]] + - [[1, 1, 1], [1, 0, -1]] + - [[1, 1, 1], [1, -1, 0]] + numerics: + newton_max_sub_step: 5 diff --git a/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml b/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml new file mode 100644 index 00000000..63bd22ff --- /dev/null +++ b/matflow/data/workflows/simulate_JAX_CPFEM_steel.yaml @@ -0,0 +1,95 @@ +# TODO: +# - test mesh equivalence in polycrystal example, from provided gmsh .msh file +# - (can just look at first sim output VTU file (contains cell oris map) + +tasks: + - schema: generate_volume_element_from_voronoi + inputs: + periodic: false + VE_grid_size: [8, 8, 8] + orientations: + data: + - [ + -0.44841562434329973, + 0.16860994035907192, + -0.3509983118612716, + 0.8045460216342235, + ] + - [ + -0.1721519777399991, + 0.1221784375558591, + -0.7486299266954833, + -0.628481788767607, + ] + - [ + 0.1755902645854752, + -0.8369269701982096, + 0.21157163334810386, + 0.4732428018470682, + ] + - [ + -0.8663740014241987, + -0.187241901559973, + 0.4235298570897932, + -0.18697331389780547, + ] + - [ + -0.9366201723048194, + 0.07470626885605275, + 0.34088334168529333, + 0.03098666788741778, + ] + - [ + -0.462754124567833, + -0.3500710695060419, + -0.3866988406362011, + -0.7167795150120938, + ] + - [ + -0.6587757263479669, + -0.16019499696083303, + -0.3423669060109474, + 0.6504898208211397, + ] + - [ + 0.1985059035197109, + 0.6384482594496135, + 0.05206231874156126, + 0.7418010118898696, + ] + unit_cell_alignment: { x: a, y: b, z: c } + representation: + type: quaternion + quat_order: scalar_vector + microstructure_seeds::from_random: + num_seeds: 8 + box_size: [0.016, 0.016, 0.016] + phase_label: 304_steel + + - schema: visualise_VE_VTK + + - schema: define_load_case + inputs: + load_case::uniaxial: + num_increments: 10 + total_time: 0.1 + target_strain: 0.01 + direction: x + + - schema: simulate_VE_loading_JAX_CPFEM + inputs: + material_parameters: + C11: 2.622e5 # MPa + C12: 1.120e5 # MPa + C44: 0.746e5 # MPa + r: 1 # latent hardening + gss_initial: 90 # initial flow stress, in MPa + h: 392.9772 # initial hardening, in MPa + t_sat: 7295.1754 # saturation strength, in MPa + gss_a: 8 # 'a' in Kalidini model + ao: 0.001 # reference strain rate + xm_d: 120 # rate sensitivity exponent (denominator) + solver_options: + name: jax_solver + numerics: + newton_max_sub_step: 8 diff --git a/matflow/param_classes/boundary_conditions.py b/matflow/param_classes/boundary_conditions.py new file mode 100644 index 00000000..22889ac6 --- /dev/null +++ b/matflow/param_classes/boundary_conditions.py @@ -0,0 +1,173 @@ +""" +A boundary condition class to represent some types of load case in a way amenable to the +JAX-CPFEM crystal plasticity code. This is not a `ParameterValue` sub-class, but rather a +helper class. + +""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Literal +from textwrap import dedent + +from typing_extensions import Final, Self + + +class BoundaryCondition: + """Simple boundary conditions container that can be used to represent a subset of + particular load cases, as corners or faces of a unit box, and the value that should be + applied within those regions. + + Parameters + ---------- + corners + Corners of the 3D box for which boundary condition should apply. Each corner + should be specified as a three-tuple of 0s or 1s. For example `(0, 0, 0)` + corresponds to the box origin, and `(0, 1, 1)` corresponds to the `x=0`, `y=1`, + `z=1` corner. Specify either `corners` or `faces`. + faces + A sequence of strings, where each string identifies the normal direction of a + specified box face, when viewed from the middle of the box. For example, the `+z` + direction corresponds to the `z=1` face, and the `-x` direction corresponds to the + `x=0` face. Specify either `corners` or `faces`. + value + A value for one or more of the "x", "y", and "z" components of the field. This can + be specified as a number or a string label that has some application-dependent + meaning. + + """ + + _DIRS: Final[tuple[str, ...]] = ("x", "y", "z") + + _JAX_CPFEM_BC_FUNC_CORNER_MAP = { + (0, 0, 0): "corner_0", + (1, 0, 0): "corner_1", + (1, 1, 0): "corner_2", + (0, 1, 0): "corner_3", + (0, 0, 1): "corner_4", + (1, 0, 1): "corner_5", + (1, 1, 1): "corner_6", + (0, 1, 1): "corner_7", + } + _JAX_CPFEM_BC_FUNC_FACE_MAP = { + "+x": "face_pos_x", + "-x": "face_neg_x", + "+y": "face_pos_y", + "-y": "face_neg_y", + "+z": "face_pos_z", + "-z": "face_neg_z", + } + _LOCATION_FUNC_CORNER_BODY_TEMPLATE = ( + "return np.allclose(point, np.asarray({corner}), atol=1e-5)" + ) + _LOCATION_FUNC_FACE_BODY_TEMPLATE = ( + "return np.isclose(point[{comp_idx}], {value}, atol=1e-5)" + ) + + def __init__( + self, + value: Mapping[Literal["x", "y", "z"], float | str], + corners: ( + Sequence[Sequence[Literal[0, 1], Literal[0, 1], Literal[0, 1]]] | None + ) = None, + faces: Sequence[str] | None = None, + ): + if sum(i is not None for i in (corners, faces)) != 1: + raise ValueError("Specify exactly one of `corners` and `faces`.") + + self.corners = corners + self.faces = faces + self.value = value + + def __repr__(self): + region_name = "corners" if self.corners is not None else "faces" + region = self.corners if self.corners is not None else self.faces + return ( + f"{self.__class__.__name__}(" + f"{region_name}={region!r}, " + f"value={self.value!r}" + f")" + ) + + @classmethod + def uniaxial(cls, direction: Literal["x", "y", "z"]) -> list[Self]: + """Generate a list of boundary conditions that correspond to uniaxial loading + along the specified direction.""" + non_axial_dirs = sorted(list(set(cls._DIRS).difference(direction))) + return [ + cls(corners=([0, 0, 0],), value={non_axial_dirs[0]: 0, non_axial_dirs[1]: 0}), + cls(faces=(f"-{direction}",), value={direction: 0}), + cls(faces=(f"+{direction}",), value={direction: "u(t)"}), + ] + + @classmethod + def one_dimensional(cls, direction: Literal["x", "y", "z"]) -> list[Self]: + """Generate a list of boundary conditions that correspond to one-dimensional + loading along the specified direction (no deformation in non-axial directions).""" + non_axial_dirs = sorted(list(set(cls._DIRS).difference(direction))) + return [ + cls(faces=(f"-{direction}",), value={"x": 0, "y": 0, "z": 0}), + cls( + faces=(f"+{direction}",), + value={non_axial_dirs[0]: 0, non_axial_dirs[1]: 0, direction: "u(t)"}, + ), + ] + + @staticmethod + def get_corner_location_func_str(corner) -> str: + """Generate Python code that defines a function that return True if the provided + 3D point is located at the specified corner, and False otherwise. + + Notes + ----- + The generated code is used as part of the problem definition script when running a + JAX-CPFEM simulation. + + """ + func_str = dedent( + """\ + def {func_name}(point): + {func_body} + + """ + ).format( + func_name=BoundaryCondition._JAX_CPFEM_BC_FUNC_CORNER_MAP[tuple(corner)], + func_body=BoundaryCondition._LOCATION_FUNC_CORNER_BODY_TEMPLATE.format( + corner=f"[{', '.join(str(i) for i in corner)}]" + ), + ) + return func_str + + @staticmethod + def get_face_location_func_str(face, domain_size) -> str: + """Generate Python code that defines a function that return True if the provided + 3D point is located within the specified plane (see the `BoundaryCondition.faces` + documentation for specification details), and False otherwise. + + Notes + ----- + The generated code is used as part of the problem definition script when running a + JAX-CPFEM simulation. + + """ + face = face.lower() + if len(face) == 1: + face = f"+{face}" + sign, dir = face + + comp_idx = BoundaryCondition._DIRS.index(dir) + func_str = dedent( + """\ + def {func_name}(point): + {func_body} + + """ + ).format( + func_name=BoundaryCondition._JAX_CPFEM_BC_FUNC_FACE_MAP[face], + func_body=BoundaryCondition._LOCATION_FUNC_FACE_BODY_TEMPLATE.format( + comp_idx=comp_idx, + value=domain_size[comp_idx] if sign == "+" else 0, + ), + ) + return func_str diff --git a/matflow/param_classes/load.py b/matflow/param_classes/load.py index 07197566..eed86987 100644 --- a/matflow/param_classes/load.py +++ b/matflow/param_classes/load.py @@ -1,6 +1,7 @@ """ Loadings to apply to a simulated sample. """ + from __future__ import annotations from collections.abc import Callable, Iterator import copy @@ -18,6 +19,7 @@ import matflow as mf from matflow.param_classes.utils import masked_array_from_list +from matflow.param_classes.boundary_conditions import BoundaryCondition logger = logging.getLogger(__name__) @@ -83,6 +85,7 @@ class LoadStep(ParameterValue): """ _DIR_IDX: Final[tuple[str, ...]] = ("x", "y", "z") + _SUPPORTS_SCALAR_STRAIN: Final[tuple[str]] = ("uniaxial", "one_dimensional") def __init__( self, @@ -213,6 +216,24 @@ def type(self) -> str: """More user-friendly access to method name.""" return self._method_name or self.__class__.__name__ + @property + def strain(self) -> float | None: + """ + For a limited subset of load step types (e.g. uniaxial), return the scalar target + strain. + """ + if self.type in self._SUPPORTS_SCALAR_STRAIN: + return self.method_args["target_strain"] + + @property + def strain_rate(self) -> float | None: + """ + For a limited subset of load step types (e.g. uniaxial), return the scalar target + strain rate. + """ + if self.type in self._SUPPORTS_SCALAR_STRAIN: + return self.method_args["target_strain_rate"] + def __repr__(self) -> str: type_str = f"type={self.type!r}, " if self.type else "" if self.direction: @@ -228,6 +249,18 @@ def __repr__(self) -> str: f")" ) + def to_dirichlet_BCs(self) -> list[BoundaryCondition]: + """For some particular types of load steps (e.g. uniaxial), we can transform them + into Dirichlet boundary conditions.""" + + if self.type == "uniaxial": + return BoundaryCondition.uniaxial(self.direction) + elif self.type == "one_dimensional": + return BoundaryCondition.one_dimensional(self.direction) + raise NotImplementedError( + "Cannot express this load step in terms of boundary conditions." + ) + @classmethod def example_uniaxial(cls) -> Self: """ @@ -249,12 +282,69 @@ def example_uniaxial(cls) -> Self: target_def_grad_rate=rate, ) + @classmethod + def __pre_process_uniaxial_like_method_args( + cls, + direction: str, + target_strain: float | None = None, + target_strain_rate: float | None = None, + target_def_grad: float | None = None, + target_def_grad_rate: float | None = None, + ) -> tuple[int, float | None, float | None, float | None, float | None]: + """Perform common processing on a subset of arguments for `uniaxial` and + `one_dimensional` load case methods.""" + + # Validation: + msg = ( + "Specify either `target_strain`, `target_strain_rate`, " + "``target_def_grad` or target_def_grad_rate`." + ) + strain_arg = ( + target_strain, + target_strain_rate, + target_def_grad, + target_def_grad_rate, + ) + if sum(s is not None for s in strain_arg) != 1: + raise ValueError(msg) + + # convert strain (rate) to deformation gradient (rate) components, and ensure both + # strain(_rate) and def_grad(_rate) are populated: + if target_strain is not None: + target_def_grad = 1 + target_strain + elif target_def_grad is not None: + target_strain = target_def_grad - 1 + + if target_strain_rate is not None: + target_def_grad_rate = target_strain_rate + elif target_def_grad_rate is not None: + target_strain_rate = target_def_grad_rate + + try: + loading_dir_idx = cls._DIR_IDX.index(direction) + except ValueError: + msg = ( + f'Loading direction "{direction}" not allowed. It should be one of "x", ' + f'"y" or "z".' + ) + raise ValueError(msg) + + return ( + loading_dir_idx, + target_strain, + target_strain_rate, + target_def_grad, + target_def_grad_rate, + ) + @classmethod def uniaxial( cls, total_time: float | int, num_increments: int, direction: str, + target_strain: float | None = None, + target_strain_rate: float | None = None, target_def_grad_rate: float | None = None, target_def_grad: float | None = None, dump_frequency: int = 1, @@ -272,44 +362,49 @@ def uniaxial( A single character, "x", "y" or "z", representing the loading direction. target_def_grad : float Target deformation gradient to achieve along the loading direction component. + target_strain: float + Target engineering strain to achieve along the loading direction. Specify at + most one of `target_strain` and `target_def_grad`. target_def_grad_rate : float Target deformation gradient rate to achieve along the loading direction component. + target_strain_rate: float + Target engineering strain rate to achieve along the loading direction. Specify + at most one of `target_strain_rate` and `target_def_grad_rate`. dump_frequency : int, optional By default, 1, meaning results are written out every increment. """ _method_name = "uniaxial" + ( + loading_dir_idx, + target_strain, + target_strain_rate, + target_def_grad, + target_def_grad_rate, + ) = cls.__pre_process_uniaxial_like_method_args( + direction, + target_strain, + target_strain_rate, + target_def_grad, + target_def_grad_rate, + ) _method_args = { "total_time": total_time, "num_increments": num_increments, "direction": direction, - "target_def_grad_rate": target_def_grad_rate, + "target_strain": target_strain, + "target_strain_rate": target_strain_rate, "target_def_grad": target_def_grad, + "target_def_grad_rate": target_def_grad_rate, "dump_frequency": dump_frequency, } - # Validation: - msg = "Specify either `target_def_grad_rate` or `target_def_grad`." - if all([t is None for t in [target_def_grad_rate, target_def_grad]]): - raise ValueError(msg) - if all([t is not None for t in [target_def_grad_rate, target_def_grad]]): - raise ValueError(msg) - if target_def_grad_rate is not None: def_grad_val = target_def_grad_rate else: def_grad_val = target_def_grad - try: - loading_dir_idx = cls._DIR_IDX.index(direction) - except ValueError: - msg = ( - f'Loading direction "{direction}" not allowed. It should be one of "x", ' - f'"y" or "z".' - ) - raise ValueError(msg) - dg_arr = np.ma.masked_array(np.zeros((3, 3)), mask=np.eye(3)) stress_arr = np.ma.masked_array(np.zeros((3, 3)), mask=np.logical_not(np.eye(3))) @@ -331,6 +426,94 @@ def uniaxial( ) return obj._remember_name_args(_method_name, _method_args) + @classmethod + def one_dimensional( + cls, + total_time: float | int, + num_increments: int, + direction: str, + target_strain: float | None = None, + target_strain_rate: float | None = None, + target_def_grad_rate: float | None = None, + target_def_grad: float | None = None, + dump_frequency: int = 1, + ) -> Self: + """ + Generate a load step that deforms in only one dimension, such that the geometrical + result is like that for a material with a Poisson's ratio (nu) of zero. + + Parameters + ---------- + total_time + Total simulation time. + num_increments + Number of simulation increments. + direction : str + A single character, "x", "y" or "z", representing the loading direction. + target_def_grad : float + Target deformation gradient to achieve along the loading direction component. + target_strain: float + Target engineering strain to achieve along the loading direction. Specify at + most one of `target_strain` and `target_def_grad`. + target_def_grad_rate : float + Target deformation gradient rate to achieve along the loading direction + component. + target_strain_rate: float + Target engineering strain rate to achieve along the loading direction. Specify + at most one of `target_strain_rate` and `target_def_grad_rate`. + dump_frequency : int, optional + By default, 1, meaning results are written out every increment. + """ + + _method_name = "one_dimensional" + ( + loading_dir_idx, + target_strain, + target_strain_rate, + target_def_grad, + target_def_grad_rate, + ) = cls.__pre_process_uniaxial_like_method_args( + direction, + target_strain, + target_strain_rate, + target_def_grad, + target_def_grad_rate, + ) + _method_args = { + "total_time": total_time, + "num_increments": num_increments, + "direction": direction, + "target_strain": target_strain, + "target_strain_rate": target_strain_rate, + "target_def_grad": target_def_grad, + "target_def_grad_rate": target_def_grad_rate, + "dump_frequency": dump_frequency, + } + + if target_def_grad_rate is not None: + def_grad_val = target_def_grad_rate + else: + def_grad_val = target_def_grad + + dg_arr = np.ma.masked_array(np.zeros((3, 3)), mask=np.zeros((3, 3))) + dg_arr[loading_dir_idx, loading_dir_idx] = def_grad_val + + stress_arr = np.ma.masked_array(np.zeros((3, 3)), mask=np.ones((3, 3))) + + def_grad_aim = dg_arr if target_def_grad is not None else None + def_grad_rate = dg_arr if target_def_grad_rate is not None else None + + obj = cls( + direction=direction, + total_time=total_time, + num_increments=num_increments, + target_def_grad=def_grad_aim, + target_def_grad_rate=def_grad_rate, + stress=stress_arr, + dump_frequency=dump_frequency, + ) + return obj._remember_name_args(_method_name, _method_args) + @classmethod def biaxial( cls, @@ -997,6 +1180,18 @@ def create_damask_loading_plan(self) -> list[dict[str, Any]]: load_steps.append(dct) return load_steps + def to_dirichlet_BCs(self) -> list[BoundaryCondition]: + """For some particular types of load cases (e.g. uniaxial), we can transform them + into Dirichlet boundary conditions.""" + + if self.num_steps == 1: + return self.steps[0].to_dirichlet_BCs() + else: + raise NotImplementedError( + "It is not currently possible to express multi-step load cases in terms " + "of boundary conditions." + ) + @classmethod def uniaxial(cls, **kwargs) -> Self: """A single-step uniaxial load case. @@ -1006,6 +1201,15 @@ def uniaxial(cls, **kwargs) -> Self: """ return cls(steps=[LoadStep.uniaxial(**kwargs)]) + @classmethod + def one_dimensional(cls, **kwargs) -> Self: + """A single-step one-dimensional load case. + + See :py:meth:`~LoadStep.one_dimensional` for argument documentation. + + """ + return cls(steps=[LoadStep.one_dimensional(**kwargs)]) + @classmethod def biaxial(cls, **kwargs) -> Self: """A single-step biaxial load case.