Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 182 additions & 9 deletions tool/asm_to_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@


PE_HEADER_RE = re.compile(r"^PE\((\d+),\s*(\d+)\):")
PE_BROADCAST_RE = re.compile(r"^PE\(\s*\*\s*,\s*\*\s*\):")
GROUP_END_RE = re.compile(r"^\}\s*\(idx_per_ii=(\d+)\)")
OP_META_RE = re.compile(r"\(t=(\d+),\s*inv_iters=(\d+)\)")
OPERAND_RE = re.compile(r"\[([^\]]+)\]")
COMPILED_II_RE = re.compile(r"#\s*Compiled II:\s*(\d+)")
ARRAY_SIZE_RE = re.compile(r"#\s*Array\s*Size:\s*(\d+)\s*[xX]\s*(\d+)")
GEMM_HEADER_RE = re.compile(
r"#\s*GEMM\s+(?:Shape\s*)?:\s*M\s*=\s*(\d+)\s*[,\s]+\s*N\s*=\s*(\d+)\s*[,\s]+\s*K\s*=\s*(\d+)"
)
REGISTER_RE = re.compile(r"^\$(\d+)$")


@dataclass
Expand All @@ -47,6 +53,39 @@ class CoreProgram:
instruction_groups: List[InstructionGroup] = field(default_factory=list)


@dataclass
class GemmMeta:
m: int
n: int
k: int


def get_register_stride(template_groups: List[InstructionGroup]) -> int:
max_reg_idx = -1
for group in template_groups:
for op in group.operations:
for operand in op.src_operands + op.dst_operands:
operand_str = operand.get("operand", "")
match = REGISTER_RE.match(operand_str)
if match:
reg_idx = int(match.group(1))
if reg_idx > max_reg_idx:
max_reg_idx = reg_idx
if max_reg_idx < 0:
return 0
return max_reg_idx + 1


def remap_register_operand(operand_value: str, block_idx: int, register_stride: int) -> str:
if register_stride <= 0:
return operand_value
match = REGISTER_RE.match(operand_value)
if not match:
return operand_value
reg_idx = int(match.group(1))
return f"${reg_idx + block_idx * register_stride}"


def parse_operand(token: str) -> Dict[str, str]:
parts = [p.strip() for p in token.split(",") if p.strip()]
if not parts:
Expand Down Expand Up @@ -89,15 +128,98 @@ def parse_operation_line(line: str) -> Optional[Tuple[str, int, int, List[Dict[s
return opcode, time_step, invalid_iters, src_operands, dst_operands


def expand_simd_template(
template_groups: List[InstructionGroup],
columns: int,
rows: int,
gemm_meta: Optional[GemmMeta] = None,
) -> List[CoreProgram]:
cores: List[CoreProgram] = []
next_op_id = 0
block_indices = [0]
register_stride = 0

if gemm_meta is not None:
if gemm_meta.m % rows != 0 or gemm_meta.n % columns != 0:
raise ValueError(
f"GEMM dimensions must be divisible by array size: M={gemm_meta.m}, N={gemm_meta.n}, rows={rows}, columns={columns}"
)
num_block_m = gemm_meta.m // rows
num_block_n = gemm_meta.n // columns
block_indices = [
block_row * num_block_n + block_col
for block_row in range(num_block_m)
for block_col in range(num_block_n)
]
register_stride = get_register_stride(template_groups)

for row in range(rows):
for column in range(columns):
groups: List[InstructionGroup] = []
for block_idx in block_indices:
for group in template_groups:
operations: List[Operation] = []
for op in group.operations:
src_operands = [dict(item) for item in op.src_operands]
dst_operands = [dict(item) for item in op.dst_operands]
for operand in src_operands:
operand["operand"] = remap_register_operand(
operand["operand"], block_idx, register_stride
)
for operand in dst_operands:
operand["operand"] = remap_register_operand(
operand["operand"], block_idx, register_stride
)
if gemm_meta is not None and op.opcode == "STORE":
if len(src_operands) < 2:
raise ValueError(
"GEMM block expansion expects STORE to have value and address src operands."
)
src_operands[1]["operand"] = str(block_idx)
operations.append(
Operation(
opcode=op.opcode,
time_step=op.time_step,
invalid_iterations=op.invalid_iterations,
src_operands=src_operands,
dst_operands=dst_operands,
op_id=next_op_id,
)
)
next_op_id += 1
groups.append(
InstructionGroup(
index_per_ii=group.index_per_ii, operations=operations
)
)
core_id = str(row * columns + column)
cores.append(
CoreProgram(
column=column,
row=row,
core_id=core_id,
instruction_groups=groups,
)
)
return cores


def parse_asm(lines: Iterable[str]) -> Tuple[List[CoreProgram], int, int, int]:
compiled_ii: Optional[int] = None
cores: Dict[Tuple[int, int], CoreProgram] = {}
core_order: List[Tuple[int, int]] = []
template_groups: List[InstructionGroup] = []
gemm_meta: Optional[GemmMeta] = None
max_x = -1
max_y = -1
op_id = 0
array_rows: Optional[int] = None
array_columns: Optional[int] = None
seen_explicit_pe = False
seen_broadcast_pe = False

current_coord: Optional[Tuple[int, int]] = None
current_is_broadcast = False
in_group = False
group_lines: List[str] = []

Expand All @@ -109,13 +231,42 @@ def parse_asm(lines: Iterable[str]) -> Tuple[List[CoreProgram], int, int, int]:
compiled_match = COMPILED_II_RE.match(line)
if compiled_match:
compiled_ii = int(compiled_match.group(1))
array_size_match = ARRAY_SIZE_RE.match(line)
if array_size_match:
array_rows = int(array_size_match.group(1))
array_columns = int(array_size_match.group(2))
gemm_header_match = GEMM_HEADER_RE.match(line)
if gemm_header_match:
if gemm_meta is not None:
raise ValueError("Multiple '# GEMM:' headers are not supported.")
gemm_m = int(gemm_header_match.group(1))
gemm_n = int(gemm_header_match.group(2))
gemm_k = int(gemm_header_match.group(3))
if gemm_m <= 0 or gemm_n <= 0 or gemm_k <= 0:
raise ValueError("GEMM dimensions must be positive in '# GEMM: M= N= K='.")
gemm_meta = GemmMeta(m=gemm_m, n=gemm_n, k=gemm_k)
continue

pe_broadcast_match = PE_BROADCAST_RE.match(line)
if pe_broadcast_match:
if seen_explicit_pe:
raise ValueError("Cannot mix PE(*,*) with explicit PE(x,y) blocks.")
if seen_broadcast_pe:
raise ValueError("Multiple PE(*,*) headers are not supported.")
seen_broadcast_pe = True
current_coord = None
current_is_broadcast = True
continue

pe_match = PE_HEADER_RE.match(line)
if pe_match:
if seen_broadcast_pe:
raise ValueError("Cannot mix explicit PE(x,y) blocks with PE(*,*).")
seen_explicit_pe = True
x = int(pe_match.group(1))
y = int(pe_match.group(2))
current_coord = (x, y)
current_is_broadcast = False
if current_coord not in cores:
core_id = str(len(cores))
cores[current_coord] = CoreProgram(column=x, row=y, core_id=core_id)
Expand All @@ -125,7 +276,7 @@ def parse_asm(lines: Iterable[str]) -> Tuple[List[CoreProgram], int, int, int]:
continue

if line == "{":
if current_coord is None:
if not current_is_broadcast and current_coord is None:
raise ValueError("Found instruction group without a PE header.")
in_group = True
group_lines = []
Expand Down Expand Up @@ -157,30 +308,52 @@ def parse_asm(lines: Iterable[str]) -> Tuple[List[CoreProgram], int, int, int]:
)
)
op_id += 1
cores[current_coord].instruction_groups.append(
InstructionGroup(index_per_ii=index_per_ii, operations=ops)
)
instruction_group = InstructionGroup(index_per_ii=index_per_ii, operations=ops)
if current_is_broadcast:
template_groups.append(instruction_group)
else:
if current_coord is None:
raise ValueError("Found explicit instruction group without PE(x,y) context.")
cores[current_coord].instruction_groups.append(instruction_group)
in_group = False
group_lines = []
continue

if in_group:
group_lines.append(line)

if max_x < 0 or max_y < 0:
if seen_broadcast_pe:
if seen_explicit_pe:
raise ValueError("Cannot mix PE(*,*) with explicit PE(x,y) blocks.")
if array_rows is None or array_columns is None:
raise ValueError("SIMD asm with PE(*,*) requires '# Array Size: <rows>x<columns>'.")
if array_rows <= 0 or array_columns <= 0:
raise ValueError("Array size must be positive in '# Array Size: <rows>x<columns>'.")
if not template_groups:
raise ValueError("PE(*,*) template contains no instruction groups.")
columns = array_columns
rows = array_rows
ordered_cores = expand_simd_template(template_groups, columns, rows, gemm_meta)
else:
if gemm_meta is not None:
raise ValueError("GEMM header currently requires PE(*,*) template input.")
if max_x < 0 or max_y < 0:
raise ValueError("No PE blocks found in asm.")
columns = max_x + 1
rows = max_y + 1
ordered_cores = [cores[coord] for coord in core_order]

if not ordered_cores:
raise ValueError("No PE blocks found in asm.")

columns = max_x + 1
rows = max_y + 1
if compiled_ii is None:
max_idx = 0
for core in cores.values():
for core in ordered_cores:
for group in core.instruction_groups:
if group.index_per_ii > max_idx:
max_idx = group.index_per_ii
compiled_ii = max_idx + 1

ordered_cores = [cores[coord] for coord in core_order]
return ordered_cores, columns, rows, compiled_ii


Expand Down
13 changes: 7 additions & 6 deletions tool/viz/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ function parseJsonLines(text) {
for (const line of lines) {
try {
const obj = JSON.parse(line);
if (obj && Number.isInteger(obj.Time)) {
if (obj && typeof obj.Time === "number" && Number.isFinite(obj.Time)) {
obj.Time = Math.round(obj.Time);
rows.push(obj);
}
} catch (_) {
Expand All @@ -163,11 +164,11 @@ function indexByTime(events) {
let minTime = Number.POSITIVE_INFINITY;
let maxTime = Number.NEGATIVE_INFINITY;
for (const e of events) {
const t = e.Time;
if (!byTime.has(t)) byTime.set(t, []);
byTime.get(t).push(e);
minTime = Math.min(minTime, t);
maxTime = Math.max(maxTime, t);
const tKey = Math.round(Number(e.Time));
if (!byTime.has(tKey)) byTime.set(tKey, []);
byTime.get(tKey).push(e);
minTime = Math.min(minTime, tKey);
maxTime = Math.max(maxTime, tKey);
}
if (!Number.isFinite(minTime) || !Number.isFinite(maxTime)) {
minTime = 0;
Expand Down