/**
 * Validate a single source operand of an ALU instruction.
 *
 * Checks, in order:
 *  - every swizzle entry is below 4, and every channel the instruction
 *    actually uses selects a component that exists in the source;
 *  - a float-typed source is 16, 32 or 64 bits wide;
 *  - the source's bit size agrees with the size the opcode declares for it,
 *    or, when both this source and the output are unsized, with the
 *    destination's bit size;
 * and finally hands the underlying nir_src to validate_src().
 *
 * \param instr  the ALU instruction that owns the source
 * \param index  which entry of instr->src to validate
 * \param state  validator state passed to validate_assert()/validate_src()
 */
static void
validate_alu_src(nir_alu_instr *instr, unsigned index, validate_state *state)
{
   nir_alu_src *src = &instr->src[index];

   unsigned src_bit_size;
   unsigned num_components;

   if (!src->src.is_ssa) {
      src_bit_size = src->src.reg.reg->bit_size;

      /* For a packed register the component count is not used; 4 makes the
       * per-channel check below no stricter than the "< 4" swizzle check.
       */
      num_components = src->src.reg.reg->is_packed
                          ? 4
                          : src->src.reg.reg->num_components;
   } else {
      src_bit_size = src->src.ssa->bit_size;
      num_components = src->src.ssa->num_components;
   }

   for (unsigned i = 0; i < 4; i++) {
      validate_assert(state, src->swizzle[i] < 4);

      /* Channels the instruction never reads may hold any in-range value. */
      if (!nir_alu_instr_channel_used(instr, index, i))
         continue;

      validate_assert(state, src->swizzle[i] < num_components);
   }

   nir_alu_type src_type = nir_op_infos[instr->op].input_types[index];

   /* 8-bit float isn't a thing */
   if (nir_alu_type_get_base_type(src_type) == nir_type_float) {
      validate_assert(state, src_bit_size == 16 || src_bit_size == 32 ||
                             src_bit_size == 64);
   }

   if (nir_alu_type_get_type_size(src_type)) {
      /* The opcode fixes this source's bit size explicitly. */
      validate_assert(state,
                      nir_alu_type_get_type_size(src_type) == src_bit_size);
   } else if (!nir_alu_type_get_type_size(nir_op_infos[instr->op].output_type)) {
      /* Source and output are both unsized, so they have to match. */
      unsigned dest_bit_size;
      if (instr->dest.dest.is_ssa)
         dest_bit_size = instr->dest.dest.ssa.bit_size;
      else
         dest_bit_size = instr->dest.dest.reg.reg->bit_size;

      validate_assert(state, dest_bit_size == src_bit_size);
   }

   validate_src(&src->src, state);
}
/**
 * Validate an ALU instruction: its opcode, the bit sizes of all sources and
 * of the destination, and then each operand individually.
 *
 * Sources/destinations whose opcode type carries an explicit size must have
 * exactly that size.  All unsized sources must share one bit size (taken
 * from the first unsized source seen), and an unsized destination must match
 * it as well.  Float-typed operands must be 16, 32 or 64 bits wide.
 *
 * \param instr  the ALU instruction to validate
 * \param state  validator state passed to validate_assert() and the
 *               per-operand validators
 */
static void
validate_alu_instr(nir_alu_instr *instr, validate_state *state)
{
   validate_assert(state, instr->op < nir_num_opcodes);

   const unsigned num_srcs = nir_op_infos[instr->op].num_inputs;

   /* Bit size shared by every unsized operand; 0 until one has been seen. */
   unsigned instr_bit_size = 0;

   for (unsigned i = 0; i < num_srcs; i++) {
      nir_alu_type src_type = nir_op_infos[instr->op].input_types[i];
      unsigned src_bit_size = nir_src_bit_size(instr->src[i].src);

      if (!nir_alu_type_get_type_size(src_type)) {
         if (instr_bit_size == 0) {
            /* First unsized source: it defines the instruction's size. */
            instr_bit_size = src_bit_size;
         } else {
            validate_assert(state, src_bit_size == instr_bit_size);
         }
      } else {
         validate_assert(state,
                         src_bit_size == nir_alu_type_get_type_size(src_type));
      }

      if (nir_alu_type_get_base_type(src_type) == nir_type_float) {
         /* 8-bit float isn't a thing */
         validate_assert(state, src_bit_size == 16 || src_bit_size == 32 ||
                                src_bit_size == 64);
      }

      validate_alu_src(instr, i, state);
   }

   nir_alu_type dest_type = nir_op_infos[instr->op].output_type;
   unsigned dest_bit_size = nir_dest_bit_size(instr->dest.dest);

   /* When the destination is the only unsized operand (instr_bit_size is
    * still 0) there is nothing to compare against, so it is vacuously valid
    * and neither branch below fires.
    */
   if (nir_alu_type_get_type_size(dest_type)) {
      validate_assert(state,
                      dest_bit_size == nir_alu_type_get_type_size(dest_type));
   } else if (instr_bit_size) {
      validate_assert(state, dest_bit_size == instr_bit_size);
   }

   if (nir_alu_type_get_base_type(dest_type) == nir_type_float) {
      /* 8-bit float isn't a thing */
      validate_assert(state, dest_bit_size == 16 || dest_bit_size == 32 ||
                             dest_bit_size == 64);
   }

   validate_alu_dest(instr, state);
}
/**
 * Validate the destination of an ALU instruction.
 *
 * Checks that the write mask stays within the destination's components
 * (skipped for packed registers), that saturate is only set when the
 * opcode's output type is float, and that the destination's bit size is
 * legal for the opcode's output type.  The underlying nir_dest is then
 * passed to validate_dest().
 *
 * \param instr  the ALU instruction whose destination is validated
 * \param state  validator state passed to validate_assert()/validate_dest()
 */
static void
validate_alu_dest(nir_alu_instr *instr, validate_state *state)
{
   nir_alu_dest *dest = &instr->dest;

   /* Gather everything that depends on SSA-vs-register in one place. */
   unsigned dest_size;
   unsigned bit_size;
   bool is_packed;

   if (dest->dest.is_ssa) {
      dest_size = dest->dest.ssa.num_components;
      bit_size = dest->dest.ssa.bit_size;
      is_packed = false;
   } else {
      dest_size = dest->dest.reg.reg->num_components;
      bit_size = dest->dest.reg.reg->bit_size;
      is_packed = dest->dest.reg.reg->is_packed;
   }

   /* The instruction must not write components that do not exist in the
    * register/SSA value.
    */
   validate_assert(state, is_packed ||
                          !(dest->write_mask & ~((1 << dest_size) - 1)));

   /* Saturate is only ever allowed on instructions whose destination type
    * is float.
    *
    * NOTE(review): the opcode is looked up through state->instr here rather
    * than through instr; presumably both name the same instruction -- confirm.
    */
   nir_alu_instr *alu = nir_instr_as_alu(state->instr);
   validate_assert(state,
                   (nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type) ==
                    nir_type_float) ||
                   !dest->saturate);

   nir_alu_type type = nir_op_infos[instr->op].output_type;

   /* 8-bit float isn't a thing */
   if (nir_alu_type_get_base_type(type) == nir_type_float) {
      validate_assert(state, bit_size == 16 || bit_size == 32 ||
                             bit_size == 64);
   }

   /* A sized output type must match the destination's bit size exactly. */
   validate_assert(state, nir_alu_type_get_type_size(type) == 0 ||
                          nir_alu_type_get_type_size(type) == bit_size);

   validate_dest(&dest->dest, state);
}
/**
 * Try to fold an ALU instruction whose sources are all constants.
 *
 * On success the result is evaluated with nir_eval_const_opcode(), a new
 * load_const instruction holding it is inserted before \p instr, all uses of
 * the old SSA destination are rewritten to the new definition, and \p instr
 * is removed and freed with ralloc_free().  The caller must not touch
 * \p instr after a true return.
 *
 * \param instr    the ALU instruction to fold
 * \param mem_ctx  context handed to nir_load_const_instr_create()
 *
 * \return true if the instruction was replaced; false (with nothing
 *         modified) if the destination or any source is not SSA, or if any
 *         source is not produced by a load_const instruction.
 */
static bool
constant_fold_alu_instr(nir_alu_instr *instr, void *mem_ctx)
{
   /* One entry per source operand. */
   nir_const_value src[NIR_MAX_VEC_COMPONENTS];

   if (!instr->dest.dest.is_ssa)
      return false;

   /* In the case that any outputs/inputs have unsized types, then we need to
    * guess the bit-size. In this case, the validator ensures that all
    * bit-sizes match so we can just take the bit-size from the first
    * output/input with an unsized type. If all the outputs/inputs are sized
    * then we don't need to guess the bit-size at all because the code we
    * generate for constant opcodes in this case already knows the sizes of
    * the types involved and does not need the provided bit-size for anything
    * (although it still requires a valid bit-size to be passed in).
    */
   unsigned bit_size = 0;
   if (!nir_alu_type_get_type_size(nir_op_infos[instr->op].output_type))
      bit_size = instr->dest.dest.ssa.bit_size;

   for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
      if (!instr->src[i].src.is_ssa)
         return false;

      /* The sized/unsized question has to be asked of the source's
       * nir_alu_type, which lives in input_types[]; input_sizes[] is not a
       * nir_alu_type and must not be fed to nir_alu_type_get_type_size().
       */
      if (bit_size == 0 &&
          !nir_alu_type_get_type_size(nir_op_infos[instr->op].input_types[i])) {
         bit_size = instr->src[i].src.ssa->bit_size;
      }

      nir_instr *src_instr = instr->src[i].src.ssa->parent_instr;

      if (src_instr->type != nir_instr_type_load_const)
         return false;
      nir_load_const_instr *load_const = nir_instr_as_load_const(src_instr);

      /* Copy the swizzled constant components, at the constant's own width,
       * into the per-source scratch value.
       */
      const unsigned num_comps = nir_ssa_alu_instr_src_components(instr, i);
      for (unsigned j = 0; j < num_comps; j++) {
         const unsigned swz = instr->src[i].swizzle[j];

         switch (load_const->def.bit_size) {
         case 64:
            src[i].u64[j] = load_const->value.u64[swz];
            break;
         case 32:
            src[i].u32[j] = load_const->value.u32[swz];
            break;
         case 16:
            src[i].u16[j] = load_const->value.u16[swz];
            break;
         case 8:
            src[i].u8[j] = load_const->value.u8[swz];
            break;
         default:
            unreachable("Invalid bit size");
         }
      }

      /* We shouldn't have any source modifiers in the optimization loop. */
      assert(!instr->src[i].abs && !instr->src[i].negate);
   }

   /* Every operand is sized: the evaluator ignores bit_size but still needs
    * a valid one.
    */
   if (bit_size == 0)
      bit_size = 32;

   /* We shouldn't have any saturate modifiers in the optimization loop. */
   assert(!instr->dest.saturate);

   nir_const_value dest =
      nir_eval_const_opcode(instr->op, instr->dest.dest.ssa.num_components,
                            bit_size, src);

   nir_load_const_instr *new_instr =
      nir_load_const_instr_create(mem_ctx,
                                  instr->dest.dest.ssa.num_components,
                                  instr->dest.dest.ssa.bit_size);

   new_instr->value = dest;

   nir_instr_insert_before(&instr->instr, &new_instr->instr);

   nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
                            nir_src_for_ssa(&new_instr->def));

   nir_instr_remove(&instr->instr);
   ralloc_free(instr);

   return true;
}