Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
sha256.cpp
Go to the documentation of this file.
2
3#include <algorithm>
4#include <array>
5#include <vector>
6
9
10namespace bb::avm2::simulation {
11
namespace {

/**
 * SHA-256 round constants K[0..63] (FIPS 180-4, section 4.2.2).
 * Entry i is added into the working state during compression round i.
 * Values mirror barretenberg/cpp/src/barretenberg/crypto/sha256/sha256.cpp.
 */
constexpr std::array<uint32_t, 64> ROUND_CONSTANTS = {
    // rounds 0..15
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
    0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
    0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    // rounds 16..31
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
    0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
    0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    // rounds 32..47
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
    0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
    0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    // rounds 48..63
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
    0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
    0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
};

} // namespace
27
43MemoryValue Sha256::ror(const MemoryValue& x, uint8_t shift)
44{
45 auto val = x.as<uint32_t>();
46 // In a rotation, we decompose into a lhs and rhs (or hi and lo) part.
47 uint32_t lo = val & ((static_cast<uint32_t>(1) << shift) - 1);
48 uint32_t hi = val >> shift;
49 uint32_t result = (lo << (32U - shift)) | hi;
50
51 // Do this outside of an assert, in case this gets built without assert
52 bool lo_in_range = gt.gt(static_cast<uint32_t>(1) << shift, lo); // Ensure the lower bits are in range
53 BB_ASSERT(lo_in_range, "Low Value in ROR out of range");
54 return MemoryValue::from<uint32_t>(result);
55}
56
70MemoryValue Sha256::shr(const MemoryValue& x, uint8_t shift)
71{
72 uint32_t input = x.as<uint32_t>();
73 // Get the lower shift bits
74 uint32_t lo = input & ((static_cast<uint32_t>(1) << shift) - 1);
75 uint32_t hi = input >> shift;
76
77 // Do this outside of an assert, in case this gets built without assert
78 bool lo_in_range = gt.gt(static_cast<uint32_t>(1) << shift, lo); // Ensure the lower bits are in range
79 BB_ASSERT(lo_in_range, "Low Value in SHR out of range");
80
81 return MemoryValue::from<uint32_t>(hi);
82}
83
94{
95 uint64_t sum = 0;
96 for (const auto& value : values) {
97 // This is safe, since we've already checked that the values are of tag U32
98 sum += value.as<uint32_t>();
99 }
100 uint32_t lo = static_cast<uint32_t>(sum);
101 uint32_t hi = static_cast<uint32_t>(sum >> 32);
102
103 // Range-check lo via GT (matches PIL RANGE_COMP_*_RHS lookups).
104 bool lo_in_range =
105 gt.gt(static_cast<uint64_t>(1) << 32, static_cast<uint64_t>(lo)); // Ensure the lower bits are in range
106 // hi is range-checked in PIL via boolean constraint (output) or range-8 lookup (compression),
107 // not via GT. We only assert here for debug purposes.
108 BB_ASSERT(lo_in_range, "Low value in MODULO_SUM out of range");
109 BB_ASSERT(hi < 256, "High value in MODULO_SUM out of range");
110 return MemoryValue::from<uint32_t>(lo);
111}
112
136 MemoryAddress state_addr,
137 MemoryAddress input_addr,
138 MemoryAddress output_addr)
139{
// SHA-256 compression over AVM memory. The function:
// - reads an 8-word U32 hash state starting at state_addr;
// - reads a 16-word U32 input block starting at input_addr;
// - runs the 64 SHA-256 rounds using the `bitwise` helper and the ror/shr/modulo_sum members
//   (which call `gt`);
// - writes the 8-word result starting at output_addr;
// - emits a Sha256CompressionEvent through `events`.
//
// Errors: throws Sha256CompressionException if any of the three address ranges exceeds
// AVM_HIGHEST_MEM_ADDRESS, if any state word is not tagged U32, or on the first input word not
// tagged U32. On that path an event whose output is filled with FF(0) is emitted first and the
// exception is then rethrown. Only Sha256CompressionException is caught; any other exception
// propagates without an event being emitted here.
//
// NOTE(review): the first line of the signature and the declarations of `state`, `input`, `w` and
// `output` are not visible in this extract. From their usage (fill / reserve+emplace_back / [0..63]
// indexing / brace-initialised 8 entries) they appear to be an 8-element array, a vector, a
// 64-element array and an 8-element array respectively. Confirm against the source.
140 uint32_t execution_clk = execution_id_manager.get_execution_id();
141 uint16_t space_id = memory.get_space_id();
142
143 // Default values are FF(0) as that is what the circuit would expect
145 state.fill(MemoryValue::from<FF>(0));
146
148 input.reserve(16);
149
150 // Check that the maximum address touched by each of the state, input, and output ranges is within the valid
// memory range. These GT checks are evaluated before the try block, so they run before any error can be raised.
151 // (1) Read the 8 element hash state from { state_addr, state_addr + 1, ..., state_addr + 7 }
152 // (2) Read the 16 element input from { input_addr, input_addr + 1, ..., input_addr + 15 }
153 // (3) Write the 8 element output to { output_addr, output_addr + 1, ..., output_addr + 7 }
154 bool state_addr_out_of_range = gt.gt(static_cast<uint64_t>(state_addr) + 7, AVM_HIGHEST_MEM_ADDRESS);
155 bool input_addr_out_of_range = gt.gt(static_cast<uint64_t>(input_addr) + 15, AVM_HIGHEST_MEM_ADDRESS);
156 bool output_addr_out_of_range = gt.gt(static_cast<uint64_t>(output_addr) + 7, AVM_HIGHEST_MEM_ADDRESS);
157
158 try {
159 if (state_addr_out_of_range || input_addr_out_of_range || output_addr_out_of_range) {
160 throw Sha256CompressionException("Memory address out of range for sha256 compression.");
161 }
162
163 // Read the hash state from memory. The state needs to be loaded atomically from memory (i.e. all 8 elements are
164 // read regardless of errors)
165 for (uint32_t i = 0; i < 8; ++i) {
166 state[i] = memory.get(state_addr + i);
167 }
168
169 // If any of the state values are not of tag U32, we throw an error.
170 if (std::ranges::any_of(state, [](const MemoryValue& val) { return val.get_tag() != MemoryTag::U32; })) {
171 throw Sha256CompressionException("Invalid tag for sha256 state values.");
172 }
173
174 // Load 16 elements representing the hash input from memory.
175 // Since the circuit loads this per row, we throw on the first error we find.
// input[i] is the element just appended (assuming `input` starts empty), so on error `input` holds
// only the words read so far, including the offending one.
176 for (uint32_t i = 0; i < 16; ++i) {
177 input.emplace_back(memory.get(input_addr + i));
178 if (input[i].get_tag() != MemoryTag::U32) {
179 throw Sha256CompressionException("Invalid tag for sha256 input values.");
180 }
181 }
182
183 // Perform sha256 compression. Taken from `vm2/simulation/lib/sha256_compression.cpp` but using
184 // the bitwise operations and MemoryValues
186
187 // Fill first 16 words with the inputs
188 for (size_t i = 0; i < 16; ++i) {
189 w[i] = input[i];
190 }
191
192 // Extend the input data into the remaining 48 words
193 for (size_t i = 16; i < 64; ++i) {
// s0 = ROTR7 ^ ROTR18 ^ SHR3 of w[i-15]; s1 = ROTR17 ^ ROTR19 ^ SHR10 of w[i-2]
// (the FIPS 180-4 message-schedule functions sigma0 / sigma1).
194 MemoryValue s0 =
195 bitwise.xor_op(bitwise.xor_op(ror(w[i - 15], 7U), ror(w[i - 15], 18U)), shr(w[i - 15], 3U));
196 MemoryValue s1 = bitwise.xor_op(bitwise.xor_op(ror(w[i - 2], 17U), ror(w[i - 2], 19U)), shr(w[i - 2], 10U));
197 // Could be explicit with an std::initializer_list<uint32_t> here, the array overload is more readable imo.
198 // std::spans are annoying to construct from literals
199 // (https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2447r2.html)
200 w[i] = modulo_sum({ { w[i - 16], w[i - 7], s0, s1 } });
201 }
202
203 // Initialize round variables with previous block output
204 MemoryValue a = state[0];
205 MemoryValue b = state[1];
206 MemoryValue c = state[2];
207 MemoryValue d = state[3];
208 MemoryValue e = state[4];
209 MemoryValue f = state[5];
210 MemoryValue g = state[6];
211 MemoryValue h = state[7];
212
213 // Apply SHA-256 compression function to the message schedule
214 for (size_t i = 0; i < 64; ++i) {
// S1 = ROTR6 ^ ROTR11 ^ ROTR25 of e; ch = (e & f) ^ (~e & g). The NOT uses MemoryValue's operator~
// directly rather than a `bitwise` call.
215 MemoryValue S1 = bitwise.xor_op(bitwise.xor_op(ror(e, 6U), ror(e, 11U)), ror(e, 25U));
216 MemoryValue ch = bitwise.xor_op(bitwise.and_op(e, f), bitwise.and_op(~e, g));
// S0 = ROTR2 ^ ROTR13 ^ ROTR22 of a; maj = (a & b) ^ (a & c) ^ (b & c).
217 MemoryValue S0 = bitwise.xor_op(bitwise.xor_op(ror(a, 2U), ror(a, 13U)), ror(a, 22U));
218 MemoryValue maj =
219 bitwise.xor_op(bitwise.xor_op(bitwise.and_op(a, b), bitwise.and_op(a, c)), bitwise.and_op(b, c));
220
221 auto prev_h = h; // Need to store the previous h value before updating it so we can use it in the modulo sum
222 h = g;
223 g = f;
224 f = e;
225 // e = d + temp1;
// temp1 = h + S1 + ch + K[i] + w[i]. Each sum is formed in a single modulo_sum call, so temp1 is
// recomputed for `a` below rather than reused.
226 e = modulo_sum({ { d, prev_h, S1, ch, MemoryValue::from<uint32_t>(ROUND_CONSTANTS[i]), w[i] } });
227 d = c;
228 c = b;
229 b = a;
230 // a = temp1 + temp2;
// temp2 = S0 + maj, both computed from the round's original a, b, c before they were shifted.
231 a = modulo_sum({ { prev_h, S1, ch, MemoryValue::from<uint32_t>(ROUND_CONSTANTS[i]), w[i], S0, maj } });
232 }
233
234 // Add into previous block output and return
// NOTE(review): the line that opens this initializer (presumably the declaration of `output`)
// is not visible in this extract.
236 modulo_sum({ { a, state[0] } }), modulo_sum({ { b, state[1] } }), modulo_sum({ { c, state[2] } }),
237 modulo_sum({ { d, state[3] } }), modulo_sum({ { e, state[4] } }), modulo_sum({ { f, state[5] } }),
238 modulo_sum({ { g, state[6] } }), modulo_sum({ { h, state[7] } }),
239 };
240
241 // Write the output back to memory.
242 for (uint32_t i = 0; i < 8; ++i) {
243 memory.set(output_addr + i, output[i]);
244 }
245
246 events.emit({ .execution_clk = execution_clk,
247 .space_id = space_id,
248 .state_addr = state_addr,
249 .input_addr = input_addr,
250 .output_addr = output_addr,
251 .state = state,
252 .input = input,
253 .output = output });
254 } catch (const Sha256CompressionException& e) {
255 // If any error occurs, we emit an event with the error message.
// NOTE(review): none of the event fields visible below carry the exception message, and the
// `output` used in this handler must be declared outside the try block (or on a line not shown
// here). Confirm both against the source.
257 output.fill(MemoryValue::from<FF>(0)); // Default output in case of error
258 events.emit({ .execution_clk = execution_clk,
259 .space_id = space_id,
260 .state_addr = state_addr,
261 .input_addr = input_addr,
262 .output_addr = output_addr,
263 .state = state,
264 .input = input,
265 .output = output });
266
267 // Rethrow the exception after emitting the event
268 throw;
269 }
270}
271
272} // namespace bb::avm2::simulation
#define BB_ASSERT(expression,...)
Definition assert.hpp:70
#define AVM_HIGHEST_MEM_ADDRESS
ValueTag get_tag() const
virtual uint32_t get_execution_id() const =0
MemoryValue modulo_sum(std::span< const MemoryValue > values)
Sum a span of U32 MemoryValues and return the result modulo 2^32.
Definition sha256.cpp:93
EventEmitterInterface< Sha256CompressionEvent > & events
Definition sha256.hpp:43
void compression(MemoryInterface &memory, MemoryAddress state_addr, MemoryAddress input_addr, MemoryAddress output_addr) override
Execute the SHA-256 compression function: read state and input from memory, compress,...
Definition sha256.cpp:135
MemoryValue shr(const MemoryValue &x, uint8_t shift)
Perform a 32-bit right shift on a MemoryValue.
Definition sha256.cpp:70
ExecutionIdGetterInterface & execution_id_manager
Definition sha256.hpp:40
MemoryValue ror(const MemoryValue &x, uint8_t shift)
Perform a 32-bit right rotation on a MemoryValue.
Definition sha256.cpp:43
FF a
FF b
AVM range check gadget for witness generation.
uint32_t MemoryAddress
Inner sum(Cont< Inner, Args... > const &in)
Definition container.hpp:70
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13