Published and fixed tests

This commit is contained in:
simonwt 2026-04-02 16:57:24 +01:00
parent 8af675d341
commit c6c7e7bb28
35 changed files with 15 additions and 6 deletions

View file

@ -0,0 +1 @@
from .apiclient import APIClient

View file

@ -0,0 +1,162 @@
import requests
import simplejson as json
from urllib.parse import urlencode
from pydantic import (
PrivateAttr, BaseModel, SecretStr, HttpUrl
)
endpoints = {
"alpha": {
"audrey": "https://eso1of8gqd.execute-api.us-east-1.amazonaws.com/alpha/",
"auth": "https://oo0wks9pbi.execute-api.us-east-1.amazonaws.com/alpha/",
"content": "https://w1yygdhayc.execute-api.us-east-1.amazonaws.com/alpha/",
"megaphone": "https://opdhjaktnl.execute-api.us-east-1.amazonaws.com/alpha/",
}
}
class APIClient(BaseModel):
# Internal State (Not passed in __init__)
## PrivateAttr to separate state from config and avoid Pylint errors
debug: bool = False
environment: str = "alpha"
client_id: SecretStr
client_secret: SecretStr
_access_token: str = PrivateAttr(default="")
_access_token_timeout: int = PrivateAttr(default=0)
def make_request(self, method: str, endpoint: str, path: str, data: dict = None, authenticate: bool = True, query_params: dict = None) -> dict:
# Make sure the endpoint is defined in the endpoints dictionary
if endpoint not in endpoints[self.environment]:
raise ValueError(f"Endpoint '{endpoint}' is not defined in the API client.")
# Make sure the HTTP method is valid
if method.upper() not in ["GET", "POST", "PUT", "DELETE"]:
raise ValueError(f"HTTP method '{method}' is not supported.")
method = method.lower()
# Construct the full URL for the API request
url = f"{endpoints[self.environment][endpoint]}{path}"
# Append query parameters to the URL if provided
if query_params is not None and isinstance(query_params, dict):
url += f"?{urlencode(query_params)}"
# Set up headers for the request
headers = {
"Content-Type": "application/json",
}
# Add the Authorization header with the access token if authentication is required
if authenticate:
token = self._access_token
headers["Authorization"] = f"Bearer {token}"
# Debugging output to show the request details before making the API call
if self.debug:
print({
"method": method.upper(),
"url": url,
"headers": headers,
"payload": data
})
# Make the API request and handle potential exceptions
try:
response = requests.request(method.upper(), url, json=data, headers=headers, timeout=20)
return_json = response.json()
except requests.exceptions.RequestException as e:
raise ConnectionError(f"An error occurred while making the request: {e}")
except json.JSONDecodeError as e:
raise ValueError(f"Response is not valid JSON: {e}")
return return_json
def sign_in(self,) -> dict:
"""
Authenticates with the API to obtain an access token.
Sends POST request to OAuth2 token endpoint with client_id and
client_secret to retrieve access token and its expiration time.
"""
# Endpoint for API requests access token for ~3 months
payload = { # Data package for endpoint to get the access token
"client_id": self.client_id.get_secret_value(),
"client_secret": self.client_secret.get_secret_value(),
"grant_type": "client_credentials"
}
token_data = self.make_request("POST", "auth", "token", payload, authenticate=False)
if 'tokenData' not in token_data:
raise ValueError(f"Unexpected response structure: {token_data}")
if self.debug:
print(f"Token data received: {token_data}") # Debug statement to check token data structure
# Access token and its timeout timestamp.
self._access_token = token_data["tokenData"]["accessToken"]
self._access_token_timeout = token_data["tokenData"]["accessTimeOut"]
return {
"access_token": self._access_token,
"access_token_timeout": self._access_token_timeout
}
def set_token(self, token_data: dict) -> None:
"""
Manually set the access token and its timeout.
Args:
token_data (dict): A dictionary containing the access token and its expiration time.
"""
self._access_token = token_data.get("access_token", "")
self._access_token_timeout = token_data.get("access_token_timeout", 0)
def is_token_valid(self) -> bool:
"""
Checks if the current access token is still valid based on the current time and the token's expiration time.
Returns:
bool: True if the token is valid, False otherwise.
"""
import time
current_time = int(time.time())
return self._access_token and current_time < self._access_token_timeout
def run_job(self, job_function, *args, **kwargs):
"""
Utility method to run a job function with the API client as the first argument.
Args:
job_function (string): The job function to execute.
*args: Positional arguments to pass to the job function.
**kwargs: Keyword arguments to pass to the job function.
Returns:
The result of the job function execution.
"""
if isinstance(job_function, str):
# Dynamically import the job function from the jobs module
module_name, func_name = job_function.rsplit('.', 1)
module = __import__(f"trustcafeapiwrapper.jobs.{module_name}", fromlist=[func_name])
job_func = getattr(module, func_name)
else:
job_func = job_function
return job_func(self, *args, **kwargs)
def wrapped(self, wrapped_data):
"""
Utility method to run a job function based on a wrapped data dictionary
containing 'job' and 'payload' keys as expected by the API client wrapper functions.
Args:
wrapped_data (dict): A dictionary with 'job' (string) and 'payload' (dict) keys.
Returns:
The result of the job function execution.
"""
return self.run_job(wrapped_data.get("job_function"), wrapped_data.get("payload", {}))

View file

View file

@ -0,0 +1,2 @@
from .get import get
from .listbyname import listbyname

View file

@ -0,0 +1,11 @@
def get(API, branch_slug: str,) -> dict:
"""
Fetches the branch/subwiki data from the API.
Args:
branch_slug (str): Slug of the user whose branch/subwiki to fetch.
Returns:
dict: The branch/subwiki data.
"""
branch_data = API.make_request("GET", "content", f"subwiki/{branch_slug}", authenticate=True)
return branch_data

View file

@ -0,0 +1,9 @@
def listbyname(API, lastEvaluatedKey=None) -> dict:
"""
Fetches a list of branchs from the API by name.
Returns:
dict: The list of comments for the post.
"""
branch_list = API.make_request("GET", "content", f"subwiki", authenticate=True, query_params=lastEvaluatedKey)
return branch_list

View file

@ -0,0 +1,2 @@
from .create import create
from .listtbypostid import listtbypostid

View file

@ -0,0 +1,12 @@
def create(API, payload: dict) -> dict:
"""
Creates a new comment in the API.
Args:
payload (dict): The data for the new post.
Returns:
dict: The comment data.
"""
comment_data = API.make_request("POST", "content", "commentcreate", data=payload, authenticate=True)
return comment_data

View file

@ -0,0 +1,11 @@
def listtbypostid(API, post_id: str,) -> dict:
"""
Fetches the list of comments for a given post ID from the API.
Args:
post_id (str): ID of the post to fetch comments for.
Returns:
dict: The list of comments for the post.
"""
comment_list = API.make_request("GET", "content", f"comment/bypostid/{post_id}", authenticate=True)
return comment_list

View file

@ -0,0 +1,2 @@
from .cafefeed import cafefeed
from .following import followingfeed

View file

@ -0,0 +1,9 @@
def cafefeed(API, lastEvaluatedKey=None):
"""
List all of a token's user's feed items from their Cafe Feed
Returns:
A list of feed items.
"""
feed = API.make_request("GET", "audrey", "feed/foryou", authenticate=True, query_params=lastEvaluatedKey)
return feed

View file

@ -0,0 +1,11 @@
def followingfeed(API, lastEvaluatedKey=None):
"""
List all of a token's user's feed items from the users and branches they are following.
Returns:
A list of feed items.
"""
feed = API.make_request("GET", "audrey", "feed/following", authenticate=True, query_params=lastEvaluatedKey)
return feed

View file

@ -0,0 +1 @@
from .listnotifications import listnotifications

View file

@ -0,0 +1,9 @@
def listnotifications(API):
"""
List all of a token's user's notifications.
Returns:
A list of notifications.
"""
notifications = API.make_request("GET", "megaphone", "inbox/notifications", authenticate=True)
return notifications

View file

@ -0,0 +1,6 @@
from .get import get
from .listbybranch import listbybranch
from .listbyuserprofile import listbyuserprofile
from .create import create
from .listall import listall
from .listpublic import listpublic

View file

@ -0,0 +1,12 @@
def create(API, payload: dict) -> dict:
"""
Creates a new post in the API.
Args:
payload (dict): The data for the new post.
Returns:
dict: The post data.
"""
post_data = API.make_request("POST", "content", "post", data=payload, authenticate=True)
return post_data

View file

@ -0,0 +1,12 @@
def get(API, post_slug: str,) -> dict:
"""
Fetches the post data from the API.
Args:
post_slug (str): Slug of the post to fetch data for.
Returns:
dict: The post data.
"""
post_data = API.make_request("GET", "content", f"post/id/{post_slug}", authenticate=True)
return post_data

View file

@ -0,0 +1,11 @@
def listall(API) -> dict:
"""
Fetches the list of posts for a given branch from the API.
Args:
branch_slug (str): Slug of the branch to fetch posts for.
Returns:
dict: The list of posts for the branch.
"""
post_list = API.make_request("GET", "content", f"post", authenticate=True)
return post_list

View file

@ -0,0 +1,11 @@
def listpublic(API) -> dict:
"""
Fetches the list of public posts from the API.
Args:
branch_slug (str): Slug of the branch to fetch posts for.
Returns:
dict: The list of public posts.
"""
post_list = API.make_request("GET", "content", f"post/public", authenticate=True)
return post_list

View file

@ -0,0 +1,11 @@
def listbybranch(API, branch_slug: str,) -> dict:
"""
Fetches the list of posts for a given branch from the API.
Args:
branch_slug (str): Slug of the branch to fetch posts for.
Returns:
dict: The list of posts for the branch.
"""
post_list = API.make_request("GET", "content", f"post/ref-subwiki/{branch_slug}", authenticate=True)
return post_list

View file

@ -0,0 +1,13 @@
def listbyuserprofile(API, user_slug: str,) -> dict:
"""
Fetches the list of posts for a given user profile from the API.
Args:
user_slug (str): Slug of the user profile to fetch posts for.
Returns:
dict: The list of posts for the user profile.
"""
# Note there is actually a reference to `/branch` in the API url
# and it should be considered for removal because that's confusing
post_list = API.make_request("GET", "content", f"post/ref-userprofile/branch/{user_slug}", authenticate=True)
return post_list

View file

@ -0,0 +1,11 @@
def listpublic(API) -> dict:
"""
Fetches the list of public posts from the API.
Args:
branch_slug (str): Slug of the branch to fetch posts for.
Returns:
dict: The list of public posts.
"""
post_list = API.make_request("GET", "content", f"post/public", authenticate=True)
return post_list

View file

@ -0,0 +1,2 @@
from .get import get
from .get import get

View file

@ -0,0 +1,11 @@
def get(API, user_slug: str,) -> dict:
"""
Fetches the user profile data from the API.
Args:
user_slug (str): Slug of the user whose profile to fetch.
Returns:
dict: The user profile data.
"""
profile_data = API.make_request("GET", "content", f"userprofile/{user_slug}", authenticate=True)
return profile_data

View file

@ -0,0 +1,2 @@
from .get_parent_pksk_from_path import get_parent_pksk_from_path
from .get_post_pksk import get_post_pksk

View file

@ -0,0 +1,15 @@
def get_parent_pksk_from_path(parent_path):
if parent_path == '/':
return 'maintrunk#maintrunk'
entity, slug = parent_path.strip('/').split('/')
if entity == 'branch':
entity = 'subwiki'
elif entity == 'user':
entity = 'userprofile'
if entity not in ['userprofile', 'subwiki']:
raise ValueError(f"Invalid parent entity: {entity}. Must be 'userprofile' or 'subwiki'.")
return f"{entity}#{slug}"

View file

@ -0,0 +1,9 @@
def get_post_pksk(parent_pksk, post_url):
post_slug = post_url.strip('/post/')
return {
"pk": parent_pksk,
"sk": f"post#{post_slug}"
}

View file

@ -0,0 +1 @@
from .create_comment import create_comment

View file

@ -0,0 +1,37 @@
from trustcafeapiwrapper.utils.get_parent_pksk_from_path import get_parent_pksk_from_path
from trustcafeapiwrapper.utils.get_post_pksk import get_post_pksk
def create_comment(comment_text, post_slug, parent_path, blur_label=None, version=3):
"""
Creates a new comment.
Args:
comment_text (str): The text content of the comment.
parent_path (str): The parent path for the comment, in the format 'userprofile/slug' or 'subwiki/slug'.
blur_label (str, optional): An optional label for blurring the comment content.
version (int, optional): The version of the comment structure to use, default is 3.
Returns:
dict: A dictionary containing the job name and payload for creating the comment
that will be processed by the API client wrapper function.
"""
parent_pksk = get_parent_pksk_from_path(parent_path)
post_pksk = get_post_pksk(parent_pksk, post_slug)
return {
"job_function": "comment.create",
"payload": {
"blurLabel": blur_label,
"commentText": comment_text,
"parent": {
"pk": post_pksk['pk'],
"sk": post_pksk['sk'],
"slug": parent_path.split('/')[-1]
},
"topLevel": {
"pk": post_pksk['pk'],
"sk": post_pksk['sk']
},
"version": version
}
}

View file

@ -0,0 +1 @@
from .create_post import create_post

View file

@ -0,0 +1,34 @@
from trustcafeapiwrapper.utils.get_parent_pksk_from_path import get_parent_pksk_from_path
def create_post(post_text, parent_path='/', blur_label=None, card_url=None, collaborative=False):
"""
Creates a new post.
Args:
post_text (str): The text content of the post.
parent_path (str, optional): The parent path for the post, default is None.
blur_label (str, optional): An optional label for blurring the post content.
card_url (str, optional): An optional URL to include as a card in the post.
collaborative (bool, optional): Whether the post is collaborative, default is False.
Returns:
dict: A dictionary containing the job name and payload for creating the post
that will be processed by the API client wrapper function.
"""
parent_pksk = get_parent_pksk_from_path(parent_path)
return {
"job_function": "post.create",
"payload": {
"blurLabel": blur_label,
"cardUrl": card_url,
"postText": post_text,
"collaborative": collaborative,
"parent": {
"pk": parent_pksk,
"sk": parent_pksk
}
}
}