/**
 * For a given starting writemask channel and corresponding source index in
 * the vec instruction, insert a MOV to the vec instruction's dest of all the
 * writemask channels that get read from the same src reg with the same
 * abs/negate source modifiers.
 *
 * The MOV has a single source operand (copied from the starting source), so
 * a later channel can only be folded into it when its vec source reads the
 * same value *and* carries the same modifiers.  Channels that do not qualify
 * are left out of the MOV and out of the returned mask.
 *
 * \param vec            the vecN instruction being lowered
 * \param start_channel  first destination channel written by the new MOV
 * \param start_src_idx  index of the vec source that feeds \p start_channel
 * \param shader         shader passed to nir_alu_instr_create() for the MOV
 *
 * Returns the writemask of our MOV, so the parent loop calling this knows
 * which ones have been processed.
 */
static unsigned
insert_mov(nir_alu_instr *vec, unsigned start_channel,
           unsigned start_src_idx, nir_shader *shader)
{
   unsigned src_idx = start_src_idx;
   assert(src_idx < nir_op_infos[vec->op].num_inputs);

   nir_alu_instr *mov = nir_alu_instr_create(shader, nir_op_imov);
   nir_alu_src_copy(&mov->src[0], &vec->src[src_idx], mov);
   nir_alu_dest_copy(&mov->dest, &vec->dest, mov);

   mov->dest.write_mask = (1u << start_channel);
   mov->src[0].swizzle[start_channel] = vec->src[src_idx].swizzle[0];
   src_idx++;

   /* src_idx advances once per enabled destination channel, so it tracks the
    * vec source that feeds channel i.
    */
   for (unsigned i = start_channel + 1; i < 4; i++) {
      if (!(vec->dest.write_mask & (1u << i)))
         continue;

      /* Comparing only the underlying src is not enough: the same value read
       * with different abs/negate modifiers cannot share one MOV source.
       */
      if (nir_srcs_equal(vec->src[src_idx].src, vec->src[start_src_idx].src) &&
          vec->src[src_idx].abs == vec->src[start_src_idx].abs &&
          vec->src[src_idx].negate == vec->src[start_src_idx].negate) {
         mov->dest.write_mask |= (1u << i);
         mov->src[0].swizzle[i] = vec->src[src_idx].swizzle[0];
      }
      src_idx++;
   }

   nir_instr_insert_before(&vec->instr, &mov->instr);

   return mov->dest.write_mask;
}
/**
 * Try to match the search value \p value against source \p src of \p instr.
 *
 * \param value           pattern node: an expression, a variable or a constant
 * \param instr           ALU instruction whose source is being examined
 * \param src             index of the source of \p instr to match
 * \param num_components  number of components to match; replaced by the op's
 *                        input size when that source is explicitly sized
 * \param swizzle         swizzle through which the source is being read;
 *                        replaced by the identity swizzle when the source is
 *                        explicitly sized
 * \param state           match state; records which variables have been seen
 *                        and the source/swizzle each one is bound to
 *
 * Returns true if the source matches \p value.  A variable seen for the first
 * time is bound in \p state as a side effect.
 */
static bool
match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
            unsigned num_components, const uint8_t *swizzle,
            struct match_state *state)
{
   uint8_t new_swizzle[4];

   /* If the source is an explicitly sized source, then we need to reset
    * both the number of components and the swizzle.
    */
   if (nir_op_infos[instr->op].input_sizes[src] != 0) {
      num_components = nir_op_infos[instr->op].input_sizes[src];
      swizzle = identity_swizzle;
   }

   /* Compose the incoming swizzle with the source's own swizzle. */
   for (unsigned i = 0; i < num_components; ++i)
      new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];

   switch (value->type) {
   case nir_search_value_expression:
      /* A sub-expression can only match the ALU instruction that produced
       * an SSA source.
       */
      if (!instr->src[src].src.is_ssa)
         return false;

      if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
         return false;

      return match_expression(nir_search_value_as_expression(value),
                              nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
                              num_components, new_swizzle, state);

   case nir_search_value_variable: {
      nir_search_variable *var = nir_search_value_as_variable(value);
      assert(var->variable < NIR_SEARCH_MAX_VARIABLES);

      if (state->variables_seen & (1u << var->variable)) {
         /* Already bound: this occurrence must read the same source through
          * the same swizzle.
          */
         if (!nir_srcs_equal(state->variables[var->variable].src,
                             instr->src[src].src))
            return false;

         assert(!instr->src[src].abs && !instr->src[src].negate);

         for (unsigned i = 0; i < num_components; ++i) {
            if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
               return false;
         }

         return true;
      } else {
         /* The constant and type constraints both inspect the instruction
          * that produced the value, which only exists for SSA sources.
          * Check is_ssa before touching src.ssa, as the expression and
          * constant cases do; a non-SSA source cannot satisfy either
          * constraint.
          */
         if (var->is_constant &&
             (!instr->src[src].src.is_ssa ||
              instr->src[src].src.ssa->parent_instr->type !=
                 nir_instr_type_load_const))
            return false;

         if (var->type != nir_type_invalid) {
            if (!instr->src[src].src.is_ssa ||
                instr->src[src].src.ssa->parent_instr->type !=
                   nir_instr_type_alu)
               return false;

            nir_alu_instr *src_alu =
               nir_instr_as_alu(instr->src[src].src.ssa->parent_instr);

            if (nir_op_infos[src_alu->op].output_type != var->type &&
                !(var->type == nir_type_bool && alu_instr_is_bool(src_alu)))
               return false;
         }

         /* First occurrence: bind the variable to this source. */
         state->variables_seen |= (1u << var->variable);
         state->variables[var->variable].src = instr->src[src].src;
         state->variables[var->variable].abs = false;
         state->variables[var->variable].negate = false;

         /* Components beyond num_components get swizzle 0. */
         for (unsigned i = 0; i < 4; ++i) {
            if (i < num_components)
               state->variables[var->variable].swizzle[i] = new_swizzle[i];
            else
               state->variables[var->variable].swizzle[i] = 0;
         }

         return true;
      }
   }

   case nir_search_value_constant: {
      nir_search_constant *const_val = nir_search_value_as_constant(value);

      if (!instr->src[src].src.is_ssa)
         return false;

      if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
         return false;

      nir_load_const_instr *load =
         nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);

      /* Every component read through the swizzle must equal the pattern's
       * constant, compared according to the type the op expects for this
       * source.
       */
      switch (nir_op_infos[instr->op].input_types[src]) {
      case nir_type_float:
         for (unsigned i = 0; i < num_components; ++i) {
            if (load->value.f[new_swizzle[i]] != const_val->data.f)
               return false;
         }
         return true;

      case nir_type_int:
      case nir_type_unsigned:
      case nir_type_bool:
         for (unsigned i = 0; i < num_components; ++i) {
            if (load->value.i[new_swizzle[i]] != const_val->data.i)
               return false;
         }
         return true;

      default:
         unreachable("Invalid alu source type");
      }
   }

   default:
      unreachable("Invalid search value type");
   }
}