forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinterval.cpp
More file actions
133 lines (111 loc) · 3.63 KB
/
interval.cpp
File metadata and controls
133 lines (111 loc) · 3.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/interval.hpp"
using namespace ngraph;
namespace {
/// Clamps `value` into the endpoint domain [0, Interval::s_max].
Interval::value_type clip(Interval::value_type value) {
    return std::max(Interval::value_type(0), std::min(Interval::s_max, value));
}

/// Saturating product of two endpoints.
/// Returns 0 if either factor is 0, and Interval::s_max if either factor already equals
/// s_max or the exact product would exceed s_max; otherwise returns a * b.
/// NOTE(review): assumes both factors are non-negative (as produced by clip()); the
/// division-based overflow test below is not valid for negative operands.
Interval::value_type clip_times(Interval::value_type a, Interval::value_type b) {
    if (a == 0 || b == 0) {
        return 0;
    } else if (a == Interval::s_max || b == Interval::s_max || a > Interval::s_max / b) {
        // a * b > s_max  <=>  a > s_max / b for positive a, b; checked without multiplying.
        return Interval::s_max;
    } else {
        return a * b;
    }
}

/// Saturating sum of two endpoints.
/// Returns Interval::s_max if either operand is s_max or the exact sum would exceed s_max;
/// otherwise returns a + b.
Interval::value_type clip_add(Interval::value_type a, Interval::value_type b) {
    if (a == Interval::s_max || b == Interval::s_max) {
        return Interval::s_max;
    }
    // Test "a + b > s_max" as "b > s_max - a" so the check itself cannot overflow (signed
    // overflow is undefined behavior). Using s_max, which is the same bound clip() and
    // clip_times() saturate at, keeps the guard and the returned saturation value in
    // agreement instead of depending on s_max being the type's numeric maximum.
    // NOTE(review): assumes a >= 0 (true for clipped endpoints); for negative a the
    // subtraction s_max - a could itself overflow.
    if (b > Interval::s_max - a) {
        return Interval::s_max;
    }
    return a + b;
}

/// Saturating difference of two endpoints, floored at 0.
/// Returns 0 when a <= b; otherwise returns Interval::s_max if a is s_max (an unbounded
/// minuend stays unbounded), else a - b.
Interval::value_type clip_minus(Interval::value_type a, Interval::value_type b) {
    if (a <= b) {
        return 0;
    }
    if (a == Interval::s_max) {
        return Interval::s_max;
    }
    return a - b;
}
}  // namespace
// Normalizes the stored endpoints in place.
// An inverted range (min > max) has no members and is encoded as [s_max, s_max];
// any other range has both endpoints clamped into [0, s_max] by clip().
void Interval::canonicalize() {
    if (m_min_val > m_max_val) {
        m_min_val = s_max;
        m_max_val = s_max;
        return;
    }
    m_min_val = clip(m_min_val);
    m_max_val = clip(m_max_val);
}
// Builds [min_val, max_val], then normalizes it with canonicalize(): an inverted range
// becomes [s_max, s_max], otherwise both endpoints are clamped into [0, s_max].
Interval::Interval(value_type min_val, value_type max_val)
    : m_min_val(min_val),
      m_max_val(max_val) {
    canonicalize();
}
// Builds the single-point interval [clip(val), clip(val)]: a negative val yields [0, 0],
// and val == s_max yields [s_max, s_max].
// Both members are initialized directly in the member-initializer list rather than being
// default-initialized and then reassigned in the constructor body (C++ Core Guidelines C.49).
// clip(val) is evaluated for each member so correctness does not depend on the member
// declaration order in the header.
Interval::Interval(value_type val) : m_min_val(clip(val)), m_max_val(clip(val)) {}
// Two intervals are equal when both their lower and their upper endpoints match exactly.
bool Interval::operator==(const Interval& interval) const {
    if (m_min_val != interval.m_min_val) {
        return false;
    }
    return m_max_val == interval.m_max_val;
}
// Logical negation of operator==.
bool Interval::operator!=(const Interval& interval) const {
    const bool same = (*this == interval);
    return !same;
}
// Interval sum: [a, b] + [c, d] = [a + c, b + d], each endpoint sum saturating at s_max
// (see clip_add). If either operand reports empty(), the result is Interval(s_max).
Interval Interval::operator+(const Interval& interval) const {
    const bool any_empty = empty() || interval.empty();
    if (any_empty) {
        return Interval(s_max);
    }
    const value_type lower = clip_add(m_min_val, interval.m_min_val);
    const value_type upper = clip_add(m_max_val, interval.m_max_val);
    return Interval(lower, upper);
}
// Replaces this interval with (*this + interval) and returns a reference to it.
Interval& Interval::operator+=(const Interval& interval) {
    *this = *this + interval;
    return *this;
}
// Interval difference: [a, b] - [c, d] = [a - d, b - c], each endpoint difference floored
// at 0 (see clip_minus). If either operand reports empty(), the result is Interval(s_max).
Interval Interval::operator-(const Interval& interval) const {
    const bool any_empty = empty() || interval.empty();
    if (any_empty) {
        return Interval(s_max);
    }
    // The smallest difference pairs our lower bound with the other's upper bound, and vice versa.
    const value_type lower = clip_minus(m_min_val, interval.m_max_val);
    const value_type upper = clip_minus(m_max_val, interval.m_min_val);
    return Interval(lower, upper);
}
// Replaces this interval with (*this - interval) and returns a reference to it.
Interval& Interval::operator-=(const Interval& interval) {
    *this = *this - interval;
    return *this;
}
// Interval product: [a, b] * [c, d] = [a * c, b * d], each endpoint product saturating at
// s_max (see clip_times). Multiplying endpoint-wise relies on the endpoints being
// non-negative, which holds for intervals built through the constructors here (clip()
// clamps to [0, s_max]).
// If an operand reports empty(), that operand is returned unchanged; this interval is
// checked first.
Interval Interval::operator*(const Interval& interval) const {
    if (empty() || interval.empty()) {
        return empty() ? *this : interval;
    }
    const value_type lower = clip_times(m_min_val, interval.m_min_val);
    const value_type upper = clip_times(m_max_val, interval.m_max_val);
    return Interval(lower, upper);
}
// Replaces this interval with (*this * interval) and returns a reference to it.
Interval& Interval::operator*=(const Interval& interval) {
    *this = *this * interval;
    return *this;
}
// Intersection: the larger of the lower bounds and the smaller of the upper bounds.
// For disjoint intervals the lower bound ends up above the upper bound, and the
// two-argument constructor's canonicalize() turns that into [s_max, s_max].
Interval Interval::operator&(const Interval& interval) const {
    const value_type lower = std::max(m_min_val, interval.m_min_val);
    const value_type upper = std::min(m_max_val, interval.m_max_val);
    return Interval(lower, upper);
}
// Replaces this interval with its intersection with `interval` and returns a reference to it.
Interval& Interval::operator&=(const Interval& interval) {
    *this = *this & interval;
    return *this;
}
// True when both endpoints of `interval` lie inside this interval, as judged by the
// single-value contains() overload (declared in the header — NOTE(review): presumably an
// inclusive min <= v <= max test; confirm there).
bool Interval::contains(const Interval& interval) const {
    const bool lower_inside = contains(interval.m_min_val);
    return lower_inside && contains(interval.m_max_val);
}
// Out-of-line definition of the static constexpr member s_max declared in the header.
// Before C++17 this definition is required whenever s_max is ODR-used — e.g. bound to a
// const reference, as std::min does in clip(). Since C++17, static constexpr data members
// are implicitly inline, so this redeclaration is redundant (deprecated but still accepted).
constexpr Interval::value_type Interval::s_max;
// Writes the interval as "Interval(<min>, <max>)". An upper bound equal to
// Interval::s_max is printed as "..." to denote an unbounded interval.
std::ostream& ov::operator<<(std::ostream& str, const Interval& interval) {
    str << "Interval(" << interval.get_min_val() << ", ";
    const auto upper = interval.get_max_val();
    if (upper != Interval::s_max) {
        str << upper;
    } else {
        str << "...";
    }
    str << ")";
    return str;
}