Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions concore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ class Concore{
int communication_iport = 0; // iport refers to input port
int communication_oport = 0; // oport refers to input port

bool isSafeName(const std::string& name) {
if (name.empty()) return false;
if (name == "." || name == "..") return false;
if (name.find("..") != std::string::npos) return false;
if (name.find('/') != std::string::npos || name.find('\\') != std::string::npos) return false;
return true;
}

public:
double delay = 1;
int retrycount = 0;
Expand Down Expand Up @@ -133,7 +141,7 @@ class Concore{
*/
void createSharedMemory(key_t key)
{
shmId_create = shmget(key, 256, IPC_CREAT | 0666);
shmId_create = shmget(key, 256, IPC_CREAT | 0600);

if (shmId_create == -1) {
std::cerr << "Failed to create shared memory segment." << std::endl;
Expand All @@ -155,7 +163,7 @@ class Concore{
{
while (true) {
// Get the shared memory segment created by Writer
shmId_get = shmget(key, 256, 0666);
shmId_get = shmget(key, 256, 0600);
// Check if shared memory exists
if (shmId_get != -1) {
break; // Break the loop if shared memory exists
Expand Down Expand Up @@ -284,6 +292,9 @@ class Concore{
* @return a string of file content
*/
vector<double> read_FM(int port, string name, string initstr){
if (!isSafeName(name)) {
return vector<double>();
}
chrono::milliseconds timespan((int)(1000*delay));
this_thread::sleep_for(timespan);
string ins;
Expand Down Expand Up @@ -440,6 +451,9 @@ class Concore{
* @param delta The delta value (default: 0).
*/
void write_FM(int port, string name, vector<double> val, int delta=0){
if (!isSafeName(name)) {
return;
}

try {
ofstream outfile;
Expand Down Expand Up @@ -470,6 +484,9 @@ class Concore{
* @param delta The delta value (default: 0).
*/
void write_FM(int port, string name, string val, int delta=0){
if (!isSafeName(name)) {
return;
}
chrono::milliseconds timespan((int)(2000*delay));
this_thread::sleep_for(timespan);
try {
Expand Down
69 changes: 61 additions & 8 deletions concore.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ast import literal_eval
import sys
import re
import json
import zmq # Added for ZeroMQ

# if windows, create script to kill this process
Expand Down Expand Up @@ -67,6 +68,52 @@ def recv_json_with_retry(self):
# Global ZeroMQ ports registry
zmq_ports = {}

def _normalize_zmq_address(port_type, address):
if not isinstance(address, str):
return address
addr = address.strip()
if port_type == "bind" and addr.startswith("tcp://*:"):
if os.environ.get("CONCORE_ALLOW_WILDCARD_BIND", "").lower() not in ("1", "true", "yes"):
safe_addr = "tcp://127.0.0.1:" + addr.split(":")[-1]
print(f"Warning: Rewriting insecure bind address '{addr}' to '{safe_addr}'.")
return safe_addr
return addr

def _is_safe_name(name):
if not isinstance(name, str) or name == "":
return False
if os.path.basename(name) != name or "/" in name or "\\" in name:
return False
if name in (".", "..") or ".." in name:
return False
return True

def _safe_channel_path(base_path_prefix, port_identifier, name):
if not _is_safe_name(name):
raise ValueError(f"Unsafe channel name: {name}")
port_str = str(int(port_identifier))
channel_dir = os.path.abspath(base_path_prefix + port_str)
file_path = os.path.abspath(os.path.join(channel_dir, name))
if os.path.commonpath([channel_dir, file_path]) != channel_dir:
raise ValueError(f"Path traversal detected for channel name: {name}")
return file_path

def _parse_untrusted_value(raw_value, default_value):
if not isinstance(raw_value, str):
return raw_value
text = raw_value.strip()
if text == "":
return default_value
try:
return json.loads(text)
except Exception:
pass
try:
return literal_eval(text)
except Exception:
return default_value


def init_zmq_port(port_name, port_type, address, socket_type_str):
"""
Initializes and registers a ZeroMQ port.
Expand All @@ -82,8 +129,9 @@ def init_zmq_port(port_name, port_type, address, socket_type_str):
try:
# Map socket type string to actual ZMQ constant (e.g., zmq.REQ, zmq.REP)
zmq_socket_type = getattr(zmq, socket_type_str.upper())
zmq_ports[port_name] = ZeroMQPort(port_type, address, zmq_socket_type)
print(f"Initialized ZMQ port: {port_name} ({socket_type_str}) on {address}")
normalized_address = _normalize_zmq_address(port_type, address)
zmq_ports[port_name] = ZeroMQPort(port_type, normalized_address, zmq_socket_type)
print(f"Initialized ZMQ port: {port_name} ({socket_type_str}) on {normalized_address}")
except AttributeError:
print(f"Error: Invalid ZMQ socket type string '{socket_type_str}'.")
except zmq.error.ZMQError as e:
Expand All @@ -106,7 +154,7 @@ def terminate_zmq():
def safe_literal_eval(filename, defaultValue):
try:
with open(filename, "r") as file:
return literal_eval(file.read())
return _parse_untrusted_value(file.read(), defaultValue)
except (FileNotFoundError, SyntaxError, ValueError, Exception) as e:
# Keep print for debugging, but can be made quieter
# print(f"Info: Error reading {filename} or file not found, using default: {e}")
Expand Down Expand Up @@ -146,7 +194,8 @@ def safe_literal_eval(filename, defaultValue):
sparams = "{'"+re.sub(';',",'",re.sub('=',"':",re.sub(' ','',sparams)))+"}"
print("converted sparams: " + sparams)
try:
params = literal_eval(sparams)
parsed_params = _parse_untrusted_value(sparams, dict())
params = parsed_params if isinstance(parsed_params, dict) else dict()
except Exception as e:
print(f"bad params content: {sparams}, error: {e}")
params = dict()
Expand Down Expand Up @@ -219,8 +268,12 @@ def read(port_identifier, name, initstr_val):
print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.")
return default_return_val

time.sleep(delay)
file_path = os.path.join(inpath+str(file_port_num), name)
time.sleep(delay)
try:
file_path = _safe_channel_path(inpath, file_port_num, name)
except ValueError as e:
print(f"Error: {e}")
return default_return_val
ins = ""

try:
Expand Down Expand Up @@ -253,7 +306,7 @@ def read(port_identifier, name, initstr_val):

# Try parsing
try:
inval = literal_eval(ins)
inval = _parse_untrusted_value(ins, default_return_val)
if isinstance(inval, list) and len(inval) > 0:
current_simtime_from_file = inval[0]
if isinstance(current_simtime_from_file, (int, float)):
Expand Down Expand Up @@ -320,7 +373,7 @@ def initval(simtime_val_str):
"""
global simtime
try:
val = literal_eval(simtime_val_str)
val = _parse_untrusted_value(simtime_val_str, [])
if isinstance(val, list) and len(val) > 0:
first_element = val[0]
if isinstance(first_element, (int, float)):
Expand Down
5 changes: 4 additions & 1 deletion demo/pwrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@
except:
init_simtime_ym = "[0.0, 0.0, 0.0]"

print(apikey)
if apikey:
print('apikey loaded')
else:
print('apikey not found')
print(yuyu)
print(name1+'='+init_simtime_u)
print(name2+'='+init_simtime_ym)
Expand Down
5 changes: 4 additions & 1 deletion ratc/pwrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@
except:
init_simtime_ym = "[0.0, 0.0, 0.0]"

print(apikey)
if apikey:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was a debug log. Ideally it shouldn't be there. But in any case, it was printing the actual apikey. Now you are printing a string. That makes the log meaningless.

print('apikey loaded')
else:
print('apikey not found')
print(yuyu)
print(name1+'='+init_simtime_u)
print(name2+'='+init_simtime_ym)
Expand Down