Grok 7.6.0
WaveletForward.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2016-2020 Grok Image Compression Inc.
3  *
4  * This source code is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU Affero General Public License, version 3,
6  * as published by the Free Software Foundation.
7  *
8  * This source code is distributed in the hope that it will be useful,
9  * but WITHOUT ANY WARRANTY; without even the implied warranty of
10  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11  * GNU Affero General Public License for more details.
12  *
13  * You should have received a copy of the GNU Affero General Public License
14  * along with this program. If not, see <http://www.gnu.org/licenses/>.
15  *
16  */
17 
18 #pragma once
19 
20 #include "grok_includes.h"
21 
22 namespace grk {
23 
24 template <typename DWT> class WaveletForward
25 {
26 
27 public:
32  bool run(TileComponent *tilec);
33 };
34 
35 
40 template <typename DWT> bool WaveletForward<DWT>::run(TileComponent *tilec){
41  if (tilec->numresolutions == 1U)
42  return true;
43 
44  size_t l_data_size = dwt_utils::max_resolution(tilec->resolutions,
45  tilec->numresolutions) * sizeof(int32_t);
46  /* overflow check */
47  if (l_data_size > SIZE_MAX) {
48  GROK_ERROR("Wavelet compress: overflow");
49  return false;
50  }
51  if (!l_data_size)
52  return false;
53 
54  bool rc = true;
55  uint32_t rw,rh,rw_next,rh_next;
56  uint8_t cas_row,cas_col;
57  uint32_t stride = tilec->buf->stride();
58  uint32_t num_decomps = (uint32_t) (tilec->numresolutions - 1);
59  auto a = tilec->buf->ptr();
60  auto cur_res = tilec->resolutions + num_decomps;
61  auto next_res = cur_res - 1;
62 
63  auto bj_array = new int32_t*[ThreadPool::get()->num_threads()];
64  for (uint32_t i = 0; i < ThreadPool::get()->num_threads(); ++i){
65  bj_array[i] = nullptr;
66  }
67  for (uint32_t i = 0; i < ThreadPool::get()->num_threads(); ++i){
68  bj_array[i] = (int32_t*)grk_aligned_malloc(l_data_size);
69  if (!bj_array[i]){
70  rc = false;
71  goto cleanup;
72  }
73  }
74 
75  for (uint32_t decompno = 0; decompno < num_decomps; ++decompno) {
76 
77  /* width of the resolution level computed */
78  rw = cur_res->x1 - cur_res->x0;
79  /* height of the resolution level computed */
80  rh = cur_res->y1 - cur_res->y0;
81  // width of the next resolution level
82  rw_next = next_res->x1 - next_res->x0;
83  //height of the next resolution level
84  rh_next = next_res->y1 - next_res->y0;
85 
86  /* 0 = non inversion on horizontal filtering 1 = inversion between low-pass and high-pass filtering */
87  cas_row = cur_res->x0 & 1;
88  /* 0 = non inversion on vertical filtering 1 = inversion between low-pass and high-pass filtering */
89  cas_col = cur_res->y0 & 1;
90 
91  // transform vertical
92  if (rw) {
93  const uint32_t linesPerThreadV = static_cast<uint32_t>(std::ceil((float)rw / (float)ThreadPool::get()->num_threads()));
94  const uint32_t s_n = rh_next;
95  const uint32_t d_n = rh - rh_next;
96  if (ThreadPool::get()->num_threads() == 1){
97  DWT wavelet;
98  for (auto m = 0U;m < std::min<uint32_t>(linesPerThreadV, rw); ++m) {
99  auto bj = bj_array[0];
100  auto aj = a + m;
101  for (uint32_t k = 0; k < rh; ++k)
102  bj[k] = aj[k * stride];
103  wavelet.encode_line(bj, (int32_t)d_n, (int32_t)s_n, cas_col);
104  dwt_utils::deinterleave_v(bj, aj, d_n, s_n, stride, cas_col);
105  }
106  } else {
107  std::vector< std::future<int> > results;
108  for(uint32_t i = 0; i < ThreadPool::get()->num_threads(); ++i) {
109  uint32_t index = i;
110  results.emplace_back(
111  ThreadPool::get()->enqueue([index, bj_array,a,
112  stride, rw,rh,
113  d_n, s_n, cas_col,
114  linesPerThreadV] {
115  DWT wavelet;
116  for (uint32_t m = index * linesPerThreadV;
117  m < std::min<uint32_t>((index+1)*linesPerThreadV, rw); ++m) {
118  auto bj = bj_array[index];
119  auto aj = a + m;
120  for (uint32_t k = 0; k < rh; ++k)
121  bj[k] = aj[k * stride];
122  wavelet.encode_line(bj, (int32_t)d_n, (int32_t)s_n, cas_col);
123  dwt_utils::deinterleave_v(bj, aj, d_n, s_n, stride, cas_col);
124  }
125  return 0;
126  })
127  );
128  }
129  for(auto &result: results)
130  result.get();
131  }
132  }
133 
134  // transform horizontal
135  if (rh){
136  const uint32_t s_n = rw_next;
137  const uint32_t d_n = rw - rw_next;
138  const uint32_t linesPerThreadH = static_cast<uint32_t>(std::ceil((float)rh / (float)ThreadPool::get()->num_threads()));
139  if (ThreadPool::get()->num_threads() == 1){
140  DWT wavelet;
141  for (auto m = 0U;m < std::min<uint32_t>(linesPerThreadH, rh); ++m) {
142  auto bj = bj_array[0];
143  auto aj = a + m * stride;
144  memcpy(bj,aj,rw << 2);
145  wavelet.encode_line(bj, (int32_t)d_n, (int32_t)s_n, cas_row);
146  dwt_utils::deinterleave_h(bj, aj, d_n, s_n, cas_row);
147  }
148 
149  } else {
150  std::vector< std::future<int> > results;
151  for(uint32_t i = 0; i < ThreadPool::get()->num_threads(); ++i) {
152  uint32_t index = i;
153  results.emplace_back(
154  ThreadPool::get()->enqueue([index, bj_array,a,
155  stride, rw,rh,
156  d_n, s_n, cas_row,
157  linesPerThreadH] {
158  DWT wavelet;
159  for (auto m = index * linesPerThreadH;
160  m < std::min<uint32_t>((index+1)*linesPerThreadH, rh); ++m) {
161  int32_t *bj = bj_array[index];
162  int32_t *aj = a + m * stride;
163  memcpy(bj,aj,rw << 2);
164  wavelet.encode_line(bj, (int32_t)d_n, (int32_t)s_n, cas_row);
165  dwt_utils::deinterleave_h(bj, aj, d_n, s_n, cas_row);
166  }
167  return 0;
168  })
169  );
170  }
171  for(auto &result: results)
172  result.get();
173  }
174  }
175  cur_res = next_res;
176  next_res--;
177  }
178 cleanup:
179  for (uint32_t i = 0; i < ThreadPool::get()->num_threads(); ++i)
180  grk_aligned_free(bj_array[i]);
181  delete[] bj_array;
182  return rc;
183 }
184 
185 }
grk::grk_aligned_free
void grk_aligned_free(void *ptr)
grok_includes.h
grk::WaveletForward
Definition: WaveletForward.h:25
grk::TileComponentBuffer::ptr
T * ptr(uint32_t resno, uint32_t bandno) const
Get pointer to band buffer.
Definition: TileComponentBuffer.h:170
grk::TileComponentBuffer::stride
uint32_t stride(uint32_t resno, uint32_t bandno) const
Get stride of band buffer.
Definition: TileComponentBuffer.h:221
grk::TileComponent::buf
TileComponentBuffer< int32_t > * buf
Definition: TileComponent.h:65
grk::TileComponent::resolutions
grk_resolution * resolutions
Definition: TileComponent.h:60
grk
Copyright (C) 2016-2020 Grok Image Compression Inc.
Definition: BitIO.h:27
grk::WaveletForward::run
bool run(TileComponent *tilec)
Forward wavelet transform in 2-D.
Definition: WaveletForward.h:40
grk::TileComponent
Definition: TileComponent.h:31
grk::dwt_utils::deinterleave_v
static void deinterleave_v(int32_t *a, int32_t *b, uint32_t d_n, uint32_t s_n, uint32_t stride, int32_t cas)
grk::TileComponent::numresolutions
uint32_t numresolutions
Definition: TileComponent.h:57
grk::grk_aligned_malloc
void * grk_aligned_malloc(size_t size)
Allocate memory aligned to a 16-byte boundary.
grk::dwt_utils::max_resolution
static uint32_t max_resolution(grk_resolution *GRK_RESTRICT r, uint32_t i)
grk::dwt_utils::deinterleave_h
static void deinterleave_h(int32_t *a, int32_t *b, uint32_t d_n, uint32_t s_n, int32_t cas)
grk::GROK_ERROR
void GROK_ERROR(const char *fmt,...)