Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions src/basic_memory/cli/commands/cloud/cloud_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,23 @@ class CloudUtilsError(Exception):
pass


def _workspace_headers(workspace: str | None = None) -> dict[str, str]:
"""Build workspace header if workspace is specified."""
if workspace:
return {"X-Workspace-ID": workspace}
return {}


async def fetch_cloud_projects(
*,
workspace: str | None = None,
api_request=make_api_request,
) -> CloudProjectList:
"""Fetch list of projects from cloud API.

Args:
workspace: Cloud workspace tenant_id to list projects from

Returns:
CloudProjectList with projects from cloud
"""
Expand All @@ -30,7 +41,11 @@ async def fetch_cloud_projects(
config = config_manager.config
host_url = config.cloud_host.rstrip("/")

response = await api_request(method="GET", url=f"{host_url}/proxy/v2/projects/")
response = await api_request(
method="GET",
url=f"{host_url}/proxy/v2/projects/",
headers=_workspace_headers(workspace),
)

return CloudProjectList.model_validate(response.json())
except Exception as e:
Expand All @@ -40,12 +55,14 @@ async def fetch_cloud_projects(
async def create_cloud_project(
project_name: str,
*,
workspace: str | None = None,
api_request=make_api_request,
) -> CloudProjectCreateResponse:
"""Create a new project on cloud.

Args:
project_name: Name of project to create
workspace: Cloud workspace tenant_id to create project in

Returns:
CloudProjectCreateResponse with project details from API
Expand All @@ -64,10 +81,13 @@ async def create_cloud_project(
set_default=False,
)

headers = {"Content-Type": "application/json"}
headers.update(_workspace_headers(workspace))

response = await api_request(
method="POST",
url=f"{host_url}/proxy/v2/projects/",
headers={"Content-Type": "application/json"},
headers=headers,
json_data=project_data.model_dump(),
)

Expand All @@ -91,17 +111,23 @@ async def sync_project(project_name: str, force_full: bool = False) -> None:
raise CloudUtilsError(f"Failed to sync project '{project_name}': {e}") from e


async def project_exists(project_name: str, *, api_request=make_api_request) -> bool:
async def project_exists(
project_name: str,
*,
workspace: str | None = None,
api_request=make_api_request,
) -> bool:
"""Check if a project exists on cloud.

Args:
project_name: Name of project to check
workspace: Cloud workspace tenant_id to check in

Returns:
True if project exists, False otherwise
"""
try:
projects = await fetch_cloud_projects(api_request=api_request)
projects = await fetch_cloud_projects(workspace=workspace, api_request=api_request)
project_names = {p.name for p in projects.projects}
return project_name in project_names
except Exception:
Expand Down
21 changes: 18 additions & 3 deletions src/basic_memory/cli/commands/cloud/upload_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
sync_project,
)
from basic_memory.cli.commands.cloud.upload import upload_path
from basic_memory.config import ConfigManager
from basic_memory.mcp.async_client import get_cloud_control_plane_client

console = Console()
Expand Down Expand Up @@ -73,12 +74,21 @@ def upload(
"""

async def _upload():
# Resolve workspace: per-project workspace_id, then default_workspace
config = ConfigManager().config
workspace = None
entry = config.projects.get(project)
if entry and entry.workspace_id:
workspace = entry.workspace_id
elif config.default_workspace:
workspace = config.default_workspace

# Check if project exists
if not await project_exists(project):
if not await project_exists(project, workspace=workspace):
if create_project:
console.print(f"[blue]Creating cloud project '{project}'...[/blue]")
try:
await create_cloud_project(project)
await create_cloud_project(project, workspace=workspace)
console.print(f"[green]Created project '{project}'[/green]")
except Exception as e:
console.print(f"[red]Failed to create project: {e}[/red]")
Expand All @@ -93,20 +103,25 @@ async def _upload():
raise typer.Exit(1)

# Perform upload (or dry run)
if workspace:
console.print(f"[dim]Using workspace: {workspace}[/dim]")
if dry_run:
console.print(
f"[yellow]DRY RUN: Showing what would be uploaded to '{project}'[/yellow]"
)
else:
console.print(f"[blue]Uploading {path} to project '{project}'...[/blue]")

def _client_factory():
return get_cloud_control_plane_client(workspace=workspace)

success = await upload_path(
path,
project,
verbose=verbose,
use_gitignore=not no_gitignore,
dry_run=dry_run,
client_cm_factory=get_cloud_control_plane_client,
client_cm_factory=_client_factory,
)
if not success:
console.print("[red]Upload failed[/red]")
Expand Down
19 changes: 16 additions & 3 deletions src/basic_memory/mcp/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,20 @@ async def _cloud_client(


@asynccontextmanager
async def get_cloud_control_plane_client() -> AsyncIterator[AsyncClient]:
async def get_cloud_control_plane_client(
workspace: Optional[str] = None,
) -> AsyncIterator[AsyncClient]:
"""Create a control-plane cloud client for endpoints outside /proxy."""
config = ConfigManager().config
timeout = _build_timeout()
token = await _resolve_cloud_token(config)
headers: dict[str, str] = {"Authorization": f"Bearer {token}"}
if workspace:
headers["X-Workspace-ID"] = workspace
logger.info(f"Creating HTTP client for cloud control plane at: {config.cloud_host}")
async with AsyncClient(
base_url=config.cloud_host,
headers={"Authorization": f"Bearer {token}"},
headers=headers,
timeout=timeout,
) as client:
yield client
Expand Down Expand Up @@ -179,8 +184,16 @@ async def get_client(
project_mode = config.get_project_mode(project_name)
if project_mode == ProjectMode.CLOUD:
logger.debug(f"Project '{project_name}' is cloud mode - using cloud proxy client")
# Resolve workspace from project config if not explicitly provided
effective_workspace = workspace
if effective_workspace is None:
entry = config.projects.get(project_name)
if entry and entry.workspace_id:
effective_workspace = entry.workspace_id
elif config.default_workspace:
effective_workspace = config.default_workspace
try:
async with _cloud_client(config, timeout, workspace=workspace) as client:
async with _cloud_client(config, timeout, workspace=effective_workspace) as client:
yield client
except RuntimeError as exc:
raise RuntimeError(
Expand Down
76 changes: 76 additions & 0 deletions tests/cli/cloud/test_cloud_api_client_and_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
make_api_request,
)
from basic_memory.cli.commands.cloud.cloud_utils import (
_workspace_headers,
create_cloud_project,
fetch_cloud_projects,
project_exists,
Expand Down Expand Up @@ -165,6 +166,81 @@ async def api_request(**kwargs):
assert seen["create_payload"]["path"] == "my-project"


def test_workspace_headers_with_workspace():
"""_workspace_headers returns X-Workspace-ID when workspace is provided."""
assert _workspace_headers("tenant-123") == {"X-Workspace-ID": "tenant-123"}


def test_workspace_headers_without_workspace():
"""_workspace_headers returns empty dict when no workspace."""
assert _workspace_headers(None) == {}
assert _workspace_headers() == {}


@pytest.mark.asyncio
async def test_cloud_utils_pass_workspace_header(config_home, config_manager):
"""fetch_cloud_projects, project_exists, and create_cloud_project pass workspace header."""
config = config_manager.load_config()
config.cloud_host = "https://cloud.example.test"
config_manager.save_config(config)

auth = CLIAuth(client_id="cid", authkit_domain="https://auth.example.test")
auth.token_file.parent.mkdir(parents=True, exist_ok=True)
auth.token_file.write_text(
'{"access_token":"token-123","refresh_token":null,"expires_at":9999999999,"token_type":"Bearer"}',
encoding="utf-8",
)

seen_headers: list[dict] = []

async def handler(request: httpx.Request) -> httpx.Response:
seen_headers.append(dict(request.headers))
if request.method == "GET":
return httpx.Response(200, json={"projects": []})
if request.method == "POST":
payload = json.loads(request.content.decode("utf-8"))
return httpx.Response(
200,
json={
"message": "created",
"status": "success",
"default": False,
"old_project": None,
"new_project": {"name": payload["name"], "path": payload["path"]},
},
)
raise AssertionError(f"Unexpected: {request.method}")

transport = httpx.MockTransport(handler)

@asynccontextmanager
async def http_client_factory():
async with httpx.AsyncClient(
transport=transport, base_url="https://cloud.example.test"
) as client:
yield client

async def api_request(**kwargs):
return await make_api_request(auth=auth, http_client_factory=http_client_factory, **kwargs)

# fetch with workspace
await fetch_cloud_projects(workspace="tenant-abc", api_request=api_request)
assert seen_headers[-1].get("x-workspace-id") == "tenant-abc"

# project_exists with workspace
await project_exists("test", workspace="tenant-abc", api_request=api_request)
assert seen_headers[-1].get("x-workspace-id") == "tenant-abc"

# create with workspace
await create_cloud_project("new-proj", workspace="tenant-abc", api_request=api_request)
assert seen_headers[-1].get("x-workspace-id") == "tenant-abc"

# Without workspace — header should not be present
seen_headers.clear()
await fetch_cloud_projects(api_request=api_request)
assert "x-workspace-id" not in seen_headers[-1]


@pytest.mark.asyncio
async def test_make_api_request_prefers_api_key_over_oauth(config_home, config_manager):
"""API key in config should be used without needing an OAuth token on disk."""
Expand Down
6 changes: 3 additions & 3 deletions tests/cli/cloud/test_upload_command_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
runner = CliRunner()


def test_cloud_upload_uses_control_plane_client(monkeypatch, tmp_path):
def test_cloud_upload_uses_control_plane_client(monkeypatch, tmp_path, config_manager):
"""Upload command should use control-plane cloud client for WebDAV PUT operations."""
import basic_memory.cli.commands.cloud.upload_command as upload_command

Expand All @@ -20,11 +20,11 @@ def test_cloud_upload_uses_control_plane_client(monkeypatch, tmp_path):

seen: dict[str, str] = {}

async def fake_project_exists(_project_name: str) -> bool:
async def fake_project_exists(_project_name: str, workspace: str | None = None) -> bool:
return True

@asynccontextmanager
async def fake_get_client():
async def fake_get_client(workspace=None):
async with httpx.AsyncClient(base_url="https://cloud.example.test") as client:
yield client

Expand Down
58 changes: 58 additions & 0 deletions tests/mcp/test_async_client_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,64 @@ async def test_get_cloud_control_plane_client_uses_oauth_token(config_manager):
assert client.headers.get("Authorization") == "Bearer oauth-control-123"


@pytest.mark.asyncio
async def test_get_cloud_control_plane_client_with_workspace(config_manager):
"""Control plane client passes X-Workspace-ID header when workspace is provided."""
cfg = config_manager.load_config()
cfg.cloud_host = "https://cloud.example.test"
cfg.cloud_api_key = "bmc_test_key_123"
config_manager.save_config(cfg)

async with get_cloud_control_plane_client(workspace="tenant-abc") as client:
assert client.headers.get("X-Workspace-ID") == "tenant-abc"

# Without workspace, header should not be present
async with get_cloud_control_plane_client() as client:
assert "X-Workspace-ID" not in client.headers


@pytest.mark.asyncio
async def test_get_client_auto_resolves_workspace_from_project_config(config_manager):
"""get_client resolves workspace from project entry when not explicitly passed."""
cfg = config_manager.load_config()
cfg.cloud_host = "https://cloud.example.test"
cfg.cloud_api_key = "bmc_test_key_123"
cfg.set_project_mode("research", ProjectMode.CLOUD)
cfg.projects["research"].workspace_id = "tenant-from-config"
config_manager.save_config(cfg)

async with get_client(project_name="research") as client:
assert client.headers.get("X-Workspace-ID") == "tenant-from-config"


@pytest.mark.asyncio
async def test_get_client_auto_resolves_workspace_from_default(config_manager):
"""get_client falls back to default_workspace when project has no workspace_id."""
cfg = config_manager.load_config()
cfg.cloud_host = "https://cloud.example.test"
cfg.cloud_api_key = "bmc_test_key_123"
cfg.set_project_mode("research", ProjectMode.CLOUD)
cfg.default_workspace = "default-tenant-456"
config_manager.save_config(cfg)

async with get_client(project_name="research") as client:
assert client.headers.get("X-Workspace-ID") == "default-tenant-456"


@pytest.mark.asyncio
async def test_get_client_explicit_workspace_overrides_config(config_manager):
"""Explicit workspace param takes priority over project config."""
cfg = config_manager.load_config()
cfg.cloud_host = "https://cloud.example.test"
cfg.cloud_api_key = "bmc_test_key_123"
cfg.set_project_mode("research", ProjectMode.CLOUD)
cfg.projects["research"].workspace_id = "tenant-from-config"
config_manager.save_config(cfg)

async with get_client(project_name="research", workspace="explicit-tenant") as client:
assert client.headers.get("X-Workspace-ID") == "explicit-tenant"


@pytest.mark.asyncio
async def test_get_cloud_control_plane_client_raises_without_credentials(config_manager):
cfg = config_manager.load_config()
Expand Down
Loading