// Writes the reference frame(s) used by the current (inter) block xd->mi[0].
//
// If the block's segment has SEG_LVL_REF_FRAME active, nothing is written:
// the block must then be single-reference and its ref_frame[0] must equal
// the segment data (both checked by assert).
//
// Otherwise:
//   * under REFERENCE_MODE_SELECT a single/compound flag is written; for any
//     other reference_mode the flag is implied and only asserted;
//   * compound blocks write one bit: ref_frame[0] == GOLDEN_FRAME;
//   * single-reference blocks write "not LAST_FRAME" and, if set, a second
//     bit "not GOLDEN_FRAME".
static void write_ref_frames(const VP9_COMMON *cm, const MACROBLOCKD *xd,
                             vpx_writer *w) {
  const MODE_INFO *const mi = xd->mi[0];
  const int segment_id = mi->segment_id;
  const int is_compound = has_second_ref(mi);

  if (segfeature_active(&cm->seg, segment_id, SEG_LVL_REF_FRAME)) {
    // Reference frame is fixed by the segment; nothing to signal.
    assert(!is_compound);
    assert(mi->ref_frame[0] ==
               get_segdata(&cm->seg, segment_id, SEG_LVL_REF_FRAME));
    return;
  }

  // Single vs. compound is only coded when each block may choose.
  if (cm->reference_mode != REFERENCE_MODE_SELECT) {
    assert(!is_compound == (cm->reference_mode == SINGLE_REFERENCE));
  } else {
    vpx_write(w, is_compound, vp9_get_reference_mode_prob(cm, xd));
  }

  if (!is_compound) {
    const int not_last = mi->ref_frame[0] != LAST_FRAME;
    vpx_write(w, not_last, vp9_get_pred_prob_single_ref_p1(cm, xd));
    if (not_last) {
      vpx_write(w, mi->ref_frame[0] != GOLDEN_FRAME,
                vp9_get_pred_prob_single_ref_p2(cm, xd));
    }
  } else {
    vpx_write(w, mi->ref_frame[0] == GOLDEN_FRAME,
              vp9_get_pred_prob_comp_ref_p(cm, xd));
  }
}
// Writes the transform size of block xd->mi[0] as up to three binary
// decisions ("larger than 4x4?", "larger than 8x8?", "larger than 16x16?").
// Coding stops as soon as the chosen size is reached, or when the next
// larger size would exceed max_txsize_lookup[] for the block size.
// Probabilities come from get_tx_probs2() for the block's maximum size.
static void write_selected_tx_size(const VP9_COMMON *cm,
                                   const MACROBLOCKD *xd, vpx_writer *w) {
  const BLOCK_SIZE block_size = xd->mi[0]->sb_type;
  const TX_SIZE chosen = xd->mi[0]->tx_size;
  const TX_SIZE largest = max_txsize_lookup[block_size];
  const vpx_prob *const probs =
      get_tx_probs2(largest, xd, &cm->fc->tx_probs);

  vpx_write(w, chosen != TX_4X4, probs[0]);
  if (chosen == TX_4X4 || largest < TX_16X16)
    return;

  vpx_write(w, chosen != TX_8X8, probs[1]);
  if (chosen == TX_8X8 || largest < TX_32X32)
    return;

  vpx_write(w, chosen != TX_16X16, probs[2]);
}
// Writes mi->skip, unless the segment has SEG_LVL_SKIP active (then nothing
// is written). Returns the effective skip value: 1 when the segment forces
// skipping, otherwise mi->skip.
static int write_skip(const VP9_COMMON *cm, const MACROBLOCKD *xd,
                      int segment_id, const MODE_INFO *mi, vpx_writer *w) {
  const int forced = segfeature_active(&cm->seg, segment_id, SEG_LVL_SKIP);

  if (!forced)
    vpx_write(w, mi->skip, vp9_get_skip_prob(cm, xd));

  return forced ? 1 : mi->skip;
}
// Writes partition type `p` for the block at (mi_row, mi_col) of size bsize.
//
// `hbs` is the offset of the block's second half. It is presumably half the
// block size in mode-info units (confirm with callers). The amount coded
// depends on which halves start inside the frame:
//   * both halves inside       -> full partition token;
//   * only the bottom outside  -> one bit (SPLIT vs HORZ) with probs[1];
//   * only the right outside   -> one bit (SPLIT vs VERT) with probs[2];
//   * both outside             -> nothing; PARTITION_SPLIT is implied.
static void write_partition(const VP9_COMMON *const cm,
                            const MACROBLOCKD *const xd,
                            int hbs, int mi_row, int mi_col,
                            PARTITION_TYPE p, BLOCK_SIZE bsize, vpx_writer *w) {
  const int ctx = partition_plane_context(xd, mi_row, mi_col, bsize);
  const vpx_prob *const probs = xd->partition_probs[ctx];
  const int bottom_inside = (mi_row + hbs) < cm->mi_rows;
  const int right_inside = (mi_col + hbs) < cm->mi_cols;

  if (!bottom_inside && !right_inside) {
    assert(p == PARTITION_SPLIT);
    return;
  }

  if (bottom_inside && right_inside) {
    vp9_write_token(w, vp9_partition_tree, probs, &partition_encodings[p]);
    return;
  }

  if (right_inside) {
    // Bottom half lies outside the frame.
    assert(p == PARTITION_SPLIT || p == PARTITION_HORZ);
    vpx_write(w, p == PARTITION_SPLIT, probs[1]);
  } else {
    // Right half lies outside the frame.
    assert(p == PARTITION_SPLIT || p == PARTITION_VERT);
    vpx_write(w, p == PARTITION_SPLIT, probs[2]);
  }
}
// Lesson driver: round-trips `uncompressed` through the boolean arithmetic
// coder using an adaptive probability model (DynProb).
//
// Encoding: copy the input into `tmp` and apply `transform`. Then code every
// bit of every byte (LSB first) with probability `encode.prob`, feeding each
// bit back into the model via record_bit() so it adapts to the data.
//
// Decoding: a second DynProb replays the same bit sequence. Assuming DynProb
// default-constructs to the same initial state, it evolves identically to the
// encoder's model. The decoded bytes are then un-transformed and compared
// with the original input.
//
// NOTE(review): the writer emits into `tmp`, the same buffer the encode loop
// reads input bytes from. This is only safe while the output position stays
// behind the byte being read, and while vpx_stop_encode's padding fits in
// `tmp`. Confirm tmp's size before feeding incompressible input.
void encode_with_adaptive_probability() {
    memcpy(tmp, uncompressed, sizeof(uncompressed));
    (*transform)(tmp); // this currently is a no-op but it may be helpful for the EXERCISE
    DynProb encode;
    vpx_writer wri ={0};
    vpx_start_encode(&wri, tmp);
    for (size_t i = 0; i < sizeof(uncompressed); ++i) {
        // LSB first; the decode loop below must use the same bit order.
        for(int bit = 1; bit < 256; bit <<= 1) {
            bool cur_bit = !!(tmp[i] & bit);
            vpx_write(&wri, cur_bit, encode.prob);
            encode.record_bit(cur_bit); // lets the encoder adapt to the data
        }
    }
    vpx_stop_encode(&wri);
    // pos is unsigned in vpx_writer; print it as such.
    printf("Buffer encoded with final prob(0) = %.2f results in %u size (%.2f%%)\n",
           encode.prob / 255.,
           (unsigned)wri.pos,
           100 * wri.pos / float(sizeof(uncompressed)));
    DynProb decode;
    vpx_reader rea={0};
    vpx_reader_init(&rea,
                    wri.buffer,
                    wri.pos);
    memset(roundtrip, 0, sizeof(roundtrip));
    for (size_t i = 0; i < sizeof(roundtrip); ++i) {
        for(int bit = 1; bit < 256; bit <<= 1) {
            if (vpx_read(&rea, decode.prob)) {
                roundtrip[i] |= bit;
                decode.record_bit(true); // mirror the encoder's model update
            } else {
                decode.record_bit(false); // mirror the encoder's model update
            }
        }
    }
    assert(vpx_reader_has_error(&rea) == 0);
    // The decoded bytes are still in transformed form. Invert the transform
    // on `roundtrip` (not on `uncompressed`, which is the untouched
    // reference) before comparing.
    (*untransform)(roundtrip); // this is, again a no-op, but may be helpful for the EXERCISE
    assert(memcmp(uncompressed, roundtrip, sizeof(uncompressed)) == 0);
}
// Searches for and writes forward updates to the coefficient probabilities
// of one transform size. Accepted updates are applied in place to
// cpi->common.fc->coef_probs[tx_size] (through old_coef_probs / *oldp).
//
// For every (plane type, ref type, band, context, node) the candidate
// probability new_coef_probs[...] is refined by
// vp9_prob_diff_update_savings_search{,_model}(). The returned savings `s`
// decide whether that node is updated (s > 0 and the probability changed).
// The PIVOT_NODE uses the model-based search with step size
// cpi->sf.coeff_prob_appx_step; the other nodes use the plain search.
//
// The strategy depends on cpi->sf.use_fast_coef_updates:
//   TWO_LOOP         - a dry run accumulates total savings. If nothing would
//                      change, or savings < 0, a single 0 bit is written.
//                      Otherwise a 1 bit is written, followed by a per-node
//                      update flag (plus the diff for updated nodes).
//   ONE_LOOP_REDUCED - a single pass. Leading "no update" flags are deferred
//                      until the first update is found, then emitted after
//                      a 1 bit. If no node is updated, a single 0 bit is
//                      written instead.
// Any other value trips assert(0); in release builds nothing is written.
static void update_coef_probs_common(vpx_writer* const bc, VP9_COMP *cpi,
                                     TX_SIZE tx_size,
                                     vp9_coeff_stats *frame_branch_ct,
                                     vp9_coeff_probs_model *new_coef_probs) {
  vp9_coeff_probs_model *old_coef_probs = cpi->common.fc->coef_probs[tx_size];
  const vpx_prob upd = DIFF_UPDATE_PROB;
  const int entropy_nodes_update = UNCONSTRAINED_NODES;
  int i, j, k, l, t;
  int stepsize = cpi->sf.coeff_prob_appx_step;

  switch (cpi->sf.use_fast_coef_updates) {
    case TWO_LOOP: {
      /* dry run to see if there is any update at all needed */
      int savings = 0;
      int update[2] = {0, 0};  // [0]: nodes left as-is, [1]: nodes updated
      for (i = 0; i < PLANE_TYPES; ++i) {
        for (j = 0; j < REF_TYPES; ++j) {
          for (k = 0; k < COEF_BANDS; ++k) {
            for (l = 0; l < BAND_COEFF_CONTEXTS(k); ++l) {
              for (t = 0; t < entropy_nodes_update; ++t) {
                vpx_prob newp = new_coef_probs[i][j][k][l][t];
                const vpx_prob oldp = old_coef_probs[i][j][k][l][t];
                int s;
                int u = 0;
                if (t == PIVOT_NODE)
                  s = vp9_prob_diff_update_savings_search_model(
                      frame_branch_ct[i][j][k][l][0], oldp, &newp, upd,
                      stepsize);
                else
                  s = vp9_prob_diff_update_savings_search(
                      frame_branch_ct[i][j][k][l][t], oldp, &newp, upd);
                if (s > 0 && newp != oldp)
                  u = 1;
                // Each node's flag cost is charged via vp9_cost_zero(upd),
                // whether or not the node is updated.
                if (u)
                  savings += s - (int)(vp9_cost_zero(upd));
                else
                  savings -= (int)(vp9_cost_zero(upd));
                update[u]++;
              }
            }
          }
        }
      }

      // Dry run done: update[] and savings decide whether anything is sent.
      /* Is coef updated at all */
      if (update[1] == 0 || savings < 0) {
        vpx_write_bit(bc, 0);
        return;
      }
      vpx_write_bit(bc, 1);
      for (i = 0; i < PLANE_TYPES; ++i) {
        for (j = 0; j < REF_TYPES; ++j) {
          for (k = 0; k < COEF_BANDS; ++k) {
            for (l = 0; l < BAND_COEFF_CONTEXTS(k); ++l) {
              // Second pass: repeat the search, then write each node's
              // update flag and, when set, the diff, applying it in place.
              for (t = 0; t < entropy_nodes_update; ++t) {
                vpx_prob newp = new_coef_probs[i][j][k][l][t];
                vpx_prob *oldp = old_coef_probs[i][j][k][l] + t;
                // NOTE(review): redundant; shadows the identical
                // function-scope `upd`.
                const vpx_prob upd = DIFF_UPDATE_PROB;
                int s;
                int u = 0;
                if (t == PIVOT_NODE)
                  s = vp9_prob_diff_update_savings_search_model(
                      frame_branch_ct[i][j][k][l][0],
                      *oldp, &newp, upd, stepsize);
                else
                  s = vp9_prob_diff_update_savings_search(
                      frame_branch_ct[i][j][k][l][t],
                      *oldp, &newp, upd);
                if (s > 0 && newp != *oldp)
                  u = 1;
                vpx_write(bc, u, upd);
                if (u) {
                  /* send/use new probability */
                  vp9_write_prob_diff_update(bc, newp, *oldp);
                  *oldp = newp;
                }
              }
            }
          }
        }
      }
      return;
    }

    case ONE_LOOP_REDUCED: {
      int updates = 0;                 // nodes updated so far
      int noupdates_before_first = 0;  // deferred "0" flags before the first update
      for (i = 0; i < PLANE_TYPES; ++i) {
        for (j = 0; j < REF_TYPES; ++j) {
          for (k = 0; k < COEF_BANDS; ++k) {
            for (l = 0; l < BAND_COEFF_CONTEXTS(k); ++l) {
              // calc probs and branch cts for this frame only
              for (t = 0; t < entropy_nodes_update; ++t) {
                vpx_prob newp = new_coef_probs[i][j][k][l][t];
                vpx_prob *oldp = old_coef_probs[i][j][k][l] + t;
                int s;
                int u = 0;

                if (t == PIVOT_NODE) {
                  s = vp9_prob_diff_update_savings_search_model(
                      frame_branch_ct[i][j][k][l][0],
                      *oldp, &newp, upd, stepsize);
                } else {
                  s = vp9_prob_diff_update_savings_search(
                      frame_branch_ct[i][j][k][l][t],
                      *oldp, &newp, upd);
                }

                if (s > 0 && newp != *oldp)
                  u = 1;
                updates += u;
                // Until the first update, just count the "no update" flags;
                // they are only written if some later node is updated.
                if (u == 0 && updates == 0) {
                  noupdates_before_first++;
                  continue;
                }
                if (u == 1 && updates == 1) {
                  int v;
                  // first update: emit the "has updates" bit and the
                  // deferred zero flags
                  vpx_write_bit(bc, 1);
                  for (v = 0; v < noupdates_before_first; ++v)
                    vpx_write(bc, 0, upd);
                }
                vpx_write(bc, u, upd);
                if (u) {
                  /* send/use new probability */
                  vp9_write_prob_diff_update(bc, newp, *oldp);
                  *oldp = newp;
                }
              }
            }
          }
        }
      }
      if (updates == 0) {
        vpx_write_bit(bc, 0);  // no updates
      }
      return;
    }
    default:
      assert(0);
  }
}
// Writes the mode info for one block `mi` (intra or inter), in this order:
//   1. segment id, when seg->update_map is set (with a temporal-prediction
//      flag when seg->temporal_update is set);
//   2. skip flag (see write_skip);
//   3. intra/inter flag, unless SEG_LVL_REF_FRAME is active for the segment;
//   4. transform size, when tx_mode == TX_MODE_SELECT, bsize >= BLOCK_8X8 and
//      the block is not a skipped inter block;
//   5a. intra: luma mode(s) (one per 4x4 sub-block below 8x8), then uv mode;
//   5b. inter: reference frame(s), inter mode(s), interpolation filter (when
//       SWITCHABLE) and motion vectors for NEWMV.
// Side effect: increments cpi->interp_filter_selected[0][mi->interp_filter]
// when the filter is coded.
static void pack_inter_mode_mvs(VP9_COMP *cpi, const MODE_INFO *mi,
                                vpx_writer *w) {
  VP9_COMMON *const cm = &cpi->common;
  const nmv_context *nmvc = &cm->fc->nmvc;
  const MACROBLOCK *const x = &cpi->td.mb;
  const MACROBLOCKD *const xd = &x->e_mbd;
  const struct segmentation *const seg = &cm->seg;
  const MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
  const PREDICTION_MODE mode = mi->mode;
  const int segment_id = mi->segment_id;
  const BLOCK_SIZE bsize = mi->sb_type;
  const int allow_hp = cm->allow_high_precision_mv;
  const int is_inter = is_inter_block(mi);
  const int is_compound = has_second_ref(mi);
  int skip, ref;

  // Segment id: with temporal update, a "predicted" flag is written first and
  // the explicit id only follows when the prediction is not used.
  if (seg->update_map) {
    if (seg->temporal_update) {
      const int pred_flag = mi->seg_id_predicted;
      vpx_prob pred_prob = vp9_get_pred_prob_seg_id(seg, xd);
      vpx_write(w, pred_flag, pred_prob);
      if (!pred_flag)
        write_segment_id(w, seg, segment_id);
    } else {
      write_segment_id(w, seg, segment_id);
    }
  }

  skip = write_skip(cm, xd, segment_id, mi, w);

  // When the segment fixes the reference frame, intra/inter is not coded.
  if (!segfeature_active(seg, segment_id, SEG_LVL_REF_FRAME))
    vpx_write(w, is_inter, vp9_get_intra_inter_prob(cm, xd));

  if (bsize >= BLOCK_8X8 && cm->tx_mode == TX_MODE_SELECT &&
      !(is_inter && skip)) {
    write_selected_tx_size(cm, xd, w);
  }

  if (!is_inter) {
    if (bsize >= BLOCK_8X8) {
      write_intra_mode(w, mode, cm->fc->y_mode_prob[size_group_lookup[bsize]]);
    } else {
      // Sub-8x8: walk the 2x2 grid of 4x4 positions, stepping by the block's
      // width/height in 4x4 units so each distinct sub-block is coded once.
      int idx, idy;
      const int num_4x4_w = num_4x4_blocks_wide_lookup[bsize];
      const int num_4x4_h = num_4x4_blocks_high_lookup[bsize];
      for (idy = 0; idy < 2; idy += num_4x4_h) {
        for (idx = 0; idx < 2; idx += num_4x4_w) {
          const PREDICTION_MODE b_mode = mi->bmi[idy * 2 + idx].as_mode;
          write_intra_mode(w, b_mode, cm->fc->y_mode_prob[0]);
        }
      }
    }
    // Chroma mode probabilities are conditioned on the (top-level) luma mode.
    write_intra_mode(w, mi->uv_mode, cm->fc->uv_mode_prob[mode]);
  } else {
    // Inter-mode probabilities are chosen by the mode context of the first
    // reference frame.
    const int mode_ctx = mbmi_ext->mode_context[mi->ref_frame[0]];
    const vpx_prob *const inter_probs = cm->fc->inter_mode_probs[mode_ctx];
    write_ref_frames(cm, xd, w);

    // If segment skip is not enabled code the mode.
    // (Only for >= 8x8; sub-8x8 modes are coded per sub-block below.)
    if (!segfeature_active(seg, segment_id, SEG_LVL_SKIP)) {
      if (bsize >= BLOCK_8X8) {
        write_inter_mode(w, mode, inter_probs);
      }
    }

    if (cm->interp_filter == SWITCHABLE) {
      const int ctx = vp9_get_pred_context_switchable_interp(xd);
      vp9_write_token(w, vp9_switchable_interp_tree,
                      cm->fc->switchable_interp_prob[ctx],
                      &switchable_interp_encodings[mi->interp_filter]);
      ++cpi->interp_filter_selected[0][mi->interp_filter];
    } else {
      assert(mi->interp_filter == cm->interp_filter);
    }

    if (bsize < BLOCK_8X8) {
      // NOTE(review): unlike the >= 8x8 path, these per-sub-block modes are
      // written regardless of SEG_LVL_SKIP. Presumably skip segments never
      // contain sub-8x8 inter blocks; confirm.
      const int num_4x4_w = num_4x4_blocks_wide_lookup[bsize];
      const int num_4x4_h = num_4x4_blocks_high_lookup[bsize];
      int idx, idy;
      for (idy = 0; idy < 2; idy += num_4x4_h) {
        for (idx = 0; idx < 2; idx += num_4x4_w) {
          const int j = idy * 2 + idx;
          const PREDICTION_MODE b_mode = mi->bmi[j].as_mode;
          write_inter_mode(w, b_mode, inter_probs);
          // NEWMV: one MV per reference, coded relative to ref_mvs[...][0].
          if (b_mode == NEWMV) {
            for (ref = 0; ref < 1 + is_compound; ++ref)
              vp9_encode_mv(cpi, w, &mi->bmi[j].as_mv[ref].as_mv,
                            &mbmi_ext->ref_mvs[mi->ref_frame[ref]][0].as_mv,
                            nmvc, allow_hp);
          }
        }
      }
    } else {
      if (mode == NEWMV) {
        for (ref = 0; ref < 1 + is_compound; ++ref)
          vp9_encode_mv(cpi, w, &mi->mv[ref].as_mv,
                        &mbmi_ext->ref_mvs[mi->ref_frame[ref]][0].as_mv, nmvc,
                        allow_hp);
      }
    }
  }
}
// Writes the coefficient tokens in [*tp, stop) using each token's own
// probability context (p->context_tree). Stops at the first EOSB_TOKEN or at
// `stop`, whichever comes first.
//
// Per token:
//   * EOB_TOKEN: a 0 on node 0.
//   * any other token: a 1 on node 0, then a 0 on node 1 for each
//     ZERO_TOKEN. Later tokens in a zero run get no node-0 bit.
//   * the non-zero token that follows: its remaining tree bits, then its
//     category extra bits (MSB first), then one raw bit taken from
//     p->extra & 1 (presumably the sign).
//
// On return *tp points one past the terminating EOSB_TOKEN, or at `stop` if
// the range ran out first. `stop` itself is never dereferenced.
// bit_depth selects the extra-bit table in CONFIG_VP9_HIGHBITDEPTH builds
// and is unused otherwise.
static void pack_mb_tokens(vpx_writer *w,
                           TOKENEXTRA **tp, const TOKENEXTRA *const stop,
                           vpx_bit_depth_t bit_depth) {
  const TOKENEXTRA *p;
  const vp9_extra_bit *const extra_bits =
#if CONFIG_VP9_HIGHBITDEPTH
    (bit_depth == VPX_BITS_12) ? vp9_extra_bits_high12 :
    (bit_depth == VPX_BITS_10) ? vp9_extra_bits_high10 :
    vp9_extra_bits;
#else
    vp9_extra_bits;
    (void) bit_depth;
#endif  // CONFIG_VP9_HIGHBITDEPTH

  for (p = *tp; p < stop && p->token != EOSB_TOKEN; ++p) {
    if (p->token == EOB_TOKEN) {
      vpx_write(w, 0, p->context_tree[0]);
      continue;
    }
    vpx_write(w, 1, p->context_tree[0]);
    while (p->token == ZERO_TOKEN) {
      vpx_write(w, 0, p->context_tree[1]);
      ++p;
      if (p == stop || p->token == EOSB_TOKEN) {
        // Skip past the EOSB marker, but never read *stop: it lies outside
        // the token range we were given.
        *tp = (TOKENEXTRA*)(uintptr_t)p +
              (p < stop && p->token == EOSB_TOKEN);
        return;
      }
    }

    {
      const int t = p->token;
      const vpx_prob *const context_tree = p->context_tree;
      assert(t != ZERO_TOKEN);
      assert(t != EOB_TOKEN);
      assert(t != EOSB_TOKEN);
      vpx_write(w, 1, context_tree[1]);
      if (t == ONE_TOKEN) {
        vpx_write(w, 0, context_tree[2]);
        vpx_write_bit(w, p->extra & 1);
      } else {  // t >= TWO_TOKEN && t < EOB_TOKEN
        const struct vp9_token *const a = &vp9_coef_encodings[t];
        const int v = a->value;
        const int n = a->len;
        const int e = p->extra;
        vpx_write(w, 1, context_tree[2]);
        // Remaining tree nodes: probabilities come from vp9_pareto8_full,
        // indexed by the PIVOT_NODE probability.
        vp9_write_tree(w, vp9_coef_con_tree,
                       vp9_pareto8_full[context_tree[PIVOT_NODE] - 1], v,
                       n - UNCONSTRAINED_NODES, 0);
        if (t >= CATEGORY1_TOKEN) {
          // Category extra bits, MSB first, each with its own probability.
          const vp9_extra_bit *const b = &extra_bits[t];
          const unsigned char *pb = b->prob;
          const int payload = e >> 1;
          int bits_left = b->len;  // number of bits in payload, assumed nonzero
          do {
            const int bb = (payload >> --bits_left) & 1;
            vpx_write(w, bb, *pb++);
          } while (bits_left);
        }
        vpx_write_bit(w, e & 1);
      }
    }
  }
  // The loop ended either at `stop` (do not dereference) or on an EOSB_TOKEN
  // inside the range, which is consumed.
  *tp = (TOKENEXTRA*)(uintptr_t)p + (p < stop && p->token == EOSB_TOKEN);
}