示例#1
0
/**
 * Emit NIR for the matrix product _src0 * _src1 and return the result.
 *
 * Both operands are passed through wrap_matrix() and the result through
 * unwrap_matrix(), so the body can index elems[] uniformly (presumably
 * wrap_matrix() presents a plain vector as a one-column matrix and passes
 * NULL through -- confirm against its definition).
 *
 * Three strategies are used:
 *  - both operands carry a cached transpose: multiply the transposes in
 *    swapped order and transpose the product;
 *  - only _src0 carries a cached transpose and the base type is float:
 *    one dot product per result component;
 *  - otherwise: each result column is a linear combination of the columns
 *    of src0, weighted by the components of the matching column of src1.
 *
 * \param b      builder; instructions are emitted through b->nb
 * \param _src0  left operand
 * \param _src1  right operand
 * \return       the product value
 */
static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{
   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   /* The dimensions must describe the operands that are actually multiplied
    * below, so they are measured only after the swap above.  Measuring them
    * first gives wrong loop bounds and a wrongly shaped destination whenever
    * the operands are not square.
    */
   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   /* When transpose_result is set this is the shape of the swapped product;
    * the final vtn_ssa_transpose() turns it back into the shape of
    * _src0 * _src1.
    */
   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* One scalar per row of src0; 4 is the largest vector size. */
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                          src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[0]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, 0));
         for (unsigned j = 1; j < src0_columns; j++) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}
示例#2
0
/**
 * Recursively emit the loads or stores for one value of \p type located at
 * \p offset, splitting it into vector and scalar accesses that are issued
 * through _vtn_load_store_tail().
 *
 * \param b          builder; instructions are emitted through b->nb
 * \param op         intrinsic opcode, forwarded unchanged to
 *                   _vtn_load_store_tail()
 * \param load       true to load into *inout, false to store from *inout
 * \param index      forwarded unchanged to _vtn_load_store_tail()
 *                   (presumably identifies the block -- confirm at callers)
 * \param offset     offset of the value; element offsets are added to it
 * \param chain      optional access chain selecting a matrix column, a
 *                   matrix element or a vector component; treated as absent
 *                   once \p chain_idx reaches chain->length
 * \param chain_idx  index of the first link of \p chain still to be applied
 * \param type       type of the value together with its layout information
 *                   (stride, row_major, offsets, array_element, members)
 * \param inout      in/out pointer to the value; several load paths
 *                   allocate or replace *inout
 */
static void
_vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
                      nir_ssa_def *index, nir_ssa_def *offset,
                      struct vtn_access_chain *chain, unsigned chain_idx,
                      struct vtn_type *type, struct vtn_ssa_value **inout)
{
   /* A fully consumed chain is the same as no chain at all. */
   if (chain && chain_idx >= chain->length)
      chain = NULL;

   /* Whole-value loads need a destination to fill in. */
   if (load && chain == NULL && *inout == NULL)
      *inout = vtn_create_ssa_value(b, type->type);

   enum glsl_base_type base_type = glsl_get_base_type(type->type);
   switch (base_type) {
   case GLSL_TYPE_UINT:
   case GLSL_TYPE_INT:
   case GLSL_TYPE_FLOAT:
   case GLSL_TYPE_BOOL:
      /* This is where things get interesting.  At this point, we've hit
       * a vector, a scalar, or a matrix.
       */
      if (glsl_type_is_matrix(type->type)) {
         if (chain == NULL) {
            /* Loading or storing the whole matrix, one vector per op. */
            struct vtn_ssa_value *transpose;
            unsigned num_ops, vec_width;
            if (type->row_major) {
               /* One access per row, each as wide as the column count. */
               num_ops = glsl_get_vector_elements(type->type);
               vec_width = glsl_get_matrix_columns(type->type);
               if (load) {
                  /* Load into a value of the transposed shape; this
                   * replaces whatever *inout held.  It is transposed back
                   * after the loop.
                   */
                  const struct glsl_type *transpose_type =
                     glsl_matrix_type(base_type, vec_width, num_ops);
                  *inout = vtn_create_ssa_value(b, transpose_type);
               } else {
                  /* Store from the transposed value.  inout is re-pointed
                   * at a local so the caller's pointer is left untouched.
                   */
                  transpose = vtn_ssa_transpose(b, *inout);
                  inout = &transpose;
               }
            } else {
               /* One access per column, each as wide as the row count. */
               num_ops = glsl_get_matrix_columns(type->type);
               vec_width = glsl_get_vector_elements(type->type);
            }

            for (unsigned i = 0; i < num_ops; i++) {
               /* Consecutive vectors are type->stride apart. */
               nir_ssa_def *elem_offset =
                  nir_iadd(&b->nb, offset,
                           nir_imm_int(&b->nb, i * type->stride));
               _vtn_load_store_tail(b, op, load, index, elem_offset,
                                    &(*inout)->elems[i],
                                    glsl_vector_type(base_type, vec_width));
            }

            /* inout still refers to the caller's pointer on the load path. */
            if (load && type->row_major)
               *inout = vtn_ssa_transpose(b, *inout);
         } else if (type->row_major) {
            /* Row-major but with an access chain. */
            /* The first link picks the column; its offset is computed with
             * the stride of type->array_element.
             */
            nir_ssa_def *col_offset =
               vtn_access_link_as_ssa(b, chain->link[chain_idx],
                                      type->array_element->stride);
            offset = nir_iadd(&b->nb, offset, col_offset);

            if (chain_idx + 1 < chain->length) {
               /* Picking off a single element */
               /* The second link picks the row, using the matrix stride. */
               nir_ssa_def *row_offset =
                  vtn_access_link_as_ssa(b, chain->link[chain_idx + 1],
                                         type->stride);
               offset = nir_iadd(&b->nb, offset, row_offset);
               /* A load always gets a fresh scalar destination here. */
               if (load)
                  *inout = vtn_create_ssa_value(b, glsl_scalar_type(base_type));
               _vtn_load_store_tail(b, op, load, index, offset, inout,
                                    glsl_scalar_type(base_type));
            } else {
               /* Grabbing a column; picking one element off each row */
               unsigned num_comps = glsl_get_vector_elements(type->type);
               const struct glsl_type *column_type =
                  glsl_get_column_type(type->type);

               /* One scalar per component; 4 is the largest vector size. */
               nir_ssa_def *comps[4];
               for (unsigned i = 0; i < num_comps; i++) {
                  nir_ssa_def *elem_offset =
                     nir_iadd(&b->nb, offset,
                              nir_imm_int(&b->nb, i * type->stride));

                  /* temp_val is a stack temporary standing in for one
                   * scalar.  It is only filled in here for stores; for
                   * loads _vtn_load_store_tail() is expected to set
                   * comp->def, which is read back below -- confirm against
                   * that helper.
                   */
                  struct vtn_ssa_value *comp, temp_val;
                  if (!load) {
                     temp_val.def = nir_channel(&b->nb, (*inout)->def, i);
                     temp_val.type = glsl_scalar_type(base_type);
                  }
                  comp = &temp_val;
                  _vtn_load_store_tail(b, op, load, index, elem_offset,
                                       &comp, glsl_scalar_type(base_type));
                  comps[i] = comp->def;
               }

               /* Reassemble the loaded scalars into one column vector. */
               if (load) {
                  if (*inout == NULL)
                     *inout = vtn_create_ssa_value(b, column_type);

                  (*inout)->def = nir_vec(&b->nb, comps, num_comps);
               }
            }
         } else {
            /* Column-major with a deref.  Apply the column offset, then
             * recurse on type->array_element with the rest of the chain.
             */
            nir_ssa_def *col_offset =
               vtn_access_link_as_ssa(b, chain->link[chain_idx], type->stride);
            offset = nir_iadd(&b->nb, offset, col_offset);

            _vtn_block_load_store(b, op, load, index, offset,
                                  chain, chain_idx + 1,
                                  type->array_element, inout);
         }
      } else if (chain == NULL) {
         /* Single whole vector */
         assert(glsl_type_is_vector_or_scalar(type->type));
         _vtn_load_store_tail(b, op, load, index, offset, inout, type->type);
      } else {
         /* Single component of a vector.  Apply the component offset, then
          * recurse on type->array_element with no chain left.
          */
         nir_ssa_def *elem_offset =
            vtn_access_link_as_ssa(b, chain->link[chain_idx], type->stride);
         offset = nir_iadd(&b->nb, offset, elem_offset);

         _vtn_block_load_store(b, op, load, index, offset, NULL, 0,
                               type->array_element, inout);
      }
      return;

   case GLSL_TYPE_ARRAY: {
      /* Recurse on every element, type->stride apart, without a chain.
       * NOTE(review): this dereferences *inout, which the allocation at the
       * top only guarantees for loads when chain is NULL -- presumably
       * callers never get here with a live chain; confirm.
       */
      unsigned elems = glsl_get_length(type->type);
      for (unsigned i = 0; i < elems; i++) {
         nir_ssa_def *elem_off =
            nir_iadd(&b->nb, offset, nir_imm_int(&b->nb, i * type->stride));
         _vtn_block_load_store(b, op, load, index, elem_off, NULL, 0,
                               type->array_element, &(*inout)->elems[i]);
      }
      return;
   }

   case GLSL_TYPE_STRUCT: {
      /* Recurse on every member at its own type->offsets[i], without a
       * chain.  The same NOTE(review) about *inout as in the array case
       * applies.
       */
      unsigned elems = glsl_get_length(type->type);
      for (unsigned i = 0; i < elems; i++) {
         nir_ssa_def *elem_off =
            nir_iadd(&b->nb, offset, nir_imm_int(&b->nb, type->offsets[i]));
         _vtn_block_load_store(b, op, load, index, elem_off, NULL, 0,
                               type->members[i], &(*inout)->elems[i]);
      }
      return;
   }

   default:
      unreachable("Invalid block member type");
   }
}